@@ -44,6 +44,7 @@ import (
4444 "px.dev/pixie/src/api/proto/uuidpb"
4545 "px.dev/pixie/src/api/proto/vizierpb"
4646 "px.dev/pixie/src/cloud/api/ptproxy"
47+ "px.dev/pixie/src/cloud/scriptmgr/scriptmgrpb"
4748 "px.dev/pixie/src/cloud/shared/vzshard"
4849 "px.dev/pixie/src/shared/cvmsgspb"
4950 "px.dev/pixie/src/shared/services/env"
@@ -65,15 +66,15 @@ type testState struct {
6566 conn * grpc.ClientConn
6667}
6768
68- func createTestState (t * testing.T ) (* testState , func (t * testing.T )) {
69+ func createTestState (t * testing.T , disableScriptModification bool ) (* testState , func (t * testing.T )) {
6970 lis := bufconn .Listen (bufSize )
7071 env := env .New ("withpixie.ai" )
7172 s := server .CreateGRPCServer (env , & server.GRPCServerOptions {})
7273
7374 nc , natsCleanup := testingutils .MustStartTestNATS (t )
7475
75- vizierpb .RegisterVizierServiceServer (s , ptproxy .NewVizierPassThroughProxy (nc , & fakeVzMgr {}))
76- vizierpb .RegisterVizierDebugServiceServer (s , ptproxy .NewVizierPassThroughProxy (nc , & fakeVzMgr {}))
76+ vizierpb .RegisterVizierServiceServer (s , ptproxy .NewVizierPassThroughProxy (nc , & fakeVzMgr {}, & fakeScriptMgr {}, disableScriptModification ))
77+ vizierpb .RegisterVizierDebugServiceServer (s , ptproxy .NewVizierPassThroughProxy (nc , & fakeVzMgr {}, & fakeScriptMgr {}, disableScriptModification ))
7778
7879 eg := errgroup.Group {}
7980 eg .Go (func () error { return s .Serve (lis ) })
@@ -112,7 +113,7 @@ func createDialer(lis *bufconn.Listener) func(ctx context.Context, url string) (
112113func TestVizierPassThroughProxy_ExecuteScript (t * testing.T ) {
113114 viper .Set ("jwt_signing_key" , "the-key" )
114115
115- ts , cleanup := createTestState (t )
116+ ts , cleanup := createTestState (t , false )
116117 defer cleanup (t )
117118
118119 client := vizierpb .NewVizierServiceClient (ts .conn )
@@ -283,10 +284,141 @@ func TestVizierPassThroughProxy_ExecuteScript(t *testing.T) {
283284 }
284285}
285286
287+ func TestVizierPassThroughProxy_ExecuteScriptWithScriptModificationEnabled (t * testing.T ) {
288+ viper .Set ("jwt_signing_key" , "the-key" )
289+
290+ ts , cleanup := createTestState (t , true )
291+ defer cleanup (t )
292+
293+ client := vizierpb .NewVizierServiceClient (ts .conn )
294+ validTestToken := testingutils .GenerateTestJWTToken (t , viper .GetString ("jwt_signing_key" ))
295+
296+ testCases := []struct {
297+ name string
298+
299+ clusterID string
300+ authToken string
301+ pxlString string
302+ respFromVizier []* cvmsgspb.V2CAPIStreamResponse
303+
304+ expGRPCError error
305+ expGRPCResponses []* vizierpb.ExecuteScriptResponse
306+ }{
307+ {
308+ name : "Request with modified pxl script" ,
309+
310+ clusterID : "00000000-1111-2222-2222-333333333333" ,
311+ pxlString : "import pxl" ,
312+ authToken : validTestToken ,
313+ expGRPCError : status .Error (codes .InvalidArgument , "Script modification has been disabled" ),
314+ expGRPCResponses : nil ,
315+ },
316+ {
317+ name : "Request with not modified pxl script" ,
318+
319+ clusterID : "00000000-1111-2222-2222-333333333333" ,
320+ pxlString : "liveview1 pxl" ,
321+ authToken : validTestToken ,
322+ expGRPCError : nil ,
323+ expGRPCResponses : []* vizierpb.ExecuteScriptResponse {
324+ {
325+ QueryID : "abc" ,
326+ },
327+ {
328+ QueryID : "abc" ,
329+ },
330+ },
331+ respFromVizier : []* cvmsgspb.V2CAPIStreamResponse {
332+ {
333+ Msg : & cvmsgspb.V2CAPIStreamResponse_ExecResp {ExecResp : & vizierpb.ExecuteScriptResponse {QueryID : "abc" }},
334+ },
335+ {
336+ Msg : & cvmsgspb.V2CAPIStreamResponse_ExecResp {ExecResp : & vizierpb.ExecuteScriptResponse {QueryID : "abc" }},
337+ },
338+ },
339+ },
340+ }
341+
342+ for _ , tc := range testCases {
343+ t .Run (tc .name , func (t * testing.T ) {
344+ ctx := context .Background ()
345+ if len (tc .authToken ) > 0 {
346+ ctx = metadata .AppendToOutgoingContext (ctx , "authorization" ,
347+ fmt .Sprintf ("bearer %s" , tc .authToken ))
348+ }
349+
350+ ctx , cancel := context .WithCancel (ctx )
351+ defer cancel ()
352+ resp , err := client .ExecuteScript (ctx ,
353+ & vizierpb.ExecuteScriptRequest {ClusterID : tc .clusterID , QueryStr : tc .pxlString })
354+ require .NoError (t , err )
355+ fv := newFakeVizier (t , uuid .FromStringOrNil (tc .clusterID ), ts .nc )
356+ fv .Run (t , tc .respFromVizier )
357+ defer fv .Stop ()
358+
359+ grpcDataCh := make (chan * vizierpb.ExecuteScriptResponse )
360+ var gotReadErr error
361+ var eg errgroup.Group
362+ eg .Go (func () error {
363+ defer close (grpcDataCh )
364+ for {
365+ d , err := resp .Recv ()
366+ if err != nil && err != io .EOF {
367+ gotReadErr = err
368+ }
369+ if err == io .EOF {
370+ return nil
371+ }
372+ if d == nil {
373+ return nil
374+ }
375+ grpcDataCh <- d
376+ }
377+ })
378+
379+ var responses []* vizierpb.ExecuteScriptResponse
380+ eg .Go (func () error {
381+ timeout := time .NewTimer (defaultTimeout )
382+ defer timeout .Stop ()
383+
384+ for {
385+ select {
386+ case <- resp .Context ().Done ():
387+ return nil
388+ case <- timeout .C :
389+ return fmt .Errorf ("timeout waiting for data on grpc channel" )
390+ case msg := <- grpcDataCh :
391+
392+ if msg == nil {
393+ return nil
394+ }
395+ responses = append (responses , msg )
396+ }
397+ }
398+ })
399+
400+ err = eg .Wait ()
401+ if tc .expGRPCError != nil {
402+ if gotReadErr == nil {
403+ t .Fatal ("Expected to get GRPC error" )
404+ }
405+ assert .Equal (t , status .Code (tc .expGRPCError ), status .Code (gotReadErr ))
406+ }
407+ if tc .expGRPCResponses == nil {
408+ if len (responses ) != 0 {
409+ t .Fatal ("Expected to get no responses" )
410+ }
411+ } else {
412+ assert .Equal (t , tc .expGRPCResponses , responses )
413+ }
414+ })
415+ }
416+ }
417+
286418func TestVizierPassThroughProxy_HealthCheck (t * testing.T ) {
287419 viper .Set ("jwt_signing_key" , "the-key" )
288420
289- ts , cleanup := createTestState (t )
421+ ts , cleanup := createTestState (t , false )
290422 defer cleanup (t )
291423
292424 client := vizierpb .NewVizierServiceClient (ts .conn )
@@ -463,7 +595,7 @@ func TestVizierPassThroughProxy_HealthCheck(t *testing.T) {
463595func TestVizierPassThroughProxy_DebugLog (t * testing.T ) {
464596 viper .Set ("jwt_signing_key" , "the-key" )
465597
466- ts , cleanup := createTestState (t )
598+ ts , cleanup := createTestState (t , false )
467599 defer cleanup (t )
468600
469601 client := vizierpb .NewVizierDebugServiceClient (ts .conn )
@@ -582,7 +714,7 @@ func TestVizierPassThroughProxy_DebugLog(t *testing.T) {
582714func TestVizierPassThroughProxy_DebugPods (t * testing.T ) {
583715 viper .Set ("jwt_signing_key" , "the-key" )
584716
585- ts , cleanup := createTestState (t )
717+ ts , cleanup := createTestState (t , false )
586718 defer cleanup (t )
587719
588720 client := vizierpb .NewVizierDebugServiceClient (ts .conn )
@@ -719,6 +851,20 @@ func TestVizierPassThroughProxy_DebugPods(t *testing.T) {
719851 }
720852}
721853
854+ type fakeScriptMgr struct {}
855+
856+ func (s * fakeScriptMgr ) GetScriptByHash (ctx context.Context , req * scriptmgrpb.GetScriptByHashReq , opts ... grpc.CallOption ) (* scriptmgrpb.GetScriptByHashResp , error ) {
857+ hash := "488f131003f415a61090901c544e0ace731e8a85b12ce0aea770273d656f08e0" // sha256 of "liveview1 pxl"
858+
859+ scripts := map [string ]bool {
860+ hash : true ,
861+ }
862+ _ , ok := scripts [req .Sha256Hash ]
863+ return & scriptmgrpb.GetScriptByHashResp {
864+ Exists : ok ,
865+ }, nil
866+ }
867+
722868type fakeVzMgr struct {}
723869
724870func (v * fakeVzMgr ) GetVizierInfo (ctx context.Context , in * uuidpb.UUID , opts ... grpc.CallOption ) (* cvmsgspb.VizierInfo , error ) {
0 commit comments