Skip to content

Commit 034804f

Browse files
committed
Add configuration option to cloud api for disabling pxl script modification
Signed-off-by: Dom Del Nano <ddelnano@gmail.com>
1 parent 79e3ec3 commit 034804f

File tree

4 files changed

+217
-15
lines changed

4 files changed

+217
-15
lines changed

src/cloud/api/api_server.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ func init() {
6666

6767
pflag.String("auth_connector_name", "", "If any, the name of the auth connector to be used with Pixie")
6868
pflag.String("auth_connector_callback_url", "", "If any, the callback URL for the auth connector")
69+
pflag.Bool("disable_script_modification", false, "If script modification should be disallowed to prevent arbitrary script execution")
6970
}
7071

7172
func main() {
@@ -213,17 +214,18 @@ func main() {
213214
authServer := &controllers.AuthServer{AuthClient: ac}
214215
cloudpb.RegisterAuthServiceServer(s.GRPCServer(), authServer)
215216

216-
vpt := ptproxy.NewVizierPassThroughProxy(nc, vc)
217-
vizierpb.RegisterVizierServiceServer(s.GRPCServer(), vpt)
218-
vizierpb.RegisterVizierDebugServiceServer(s.GRPCServer(), vpt)
219-
220217
sm, err := apienv.NewScriptMgrServiceClient()
221218
if err != nil {
222219
log.WithError(err).Fatal("Failed to init scriptmgr client.")
223220
}
224221
sms := &controllers.ScriptMgrServer{ScriptMgr: sm}
225222
cloudpb.RegisterScriptMgrServer(s.GRPCServer(), sms)
226223

224+
disableScriptModification := viper.GetBool("disable_script_modification")
225+
vpt := ptproxy.NewVizierPassThroughProxy(nc, vc, sm, disableScriptModification)
226+
vizierpb.RegisterVizierServiceServer(s.GRPCServer(), vpt)
227+
vizierpb.RegisterVizierDebugServiceServer(s.GRPCServer(), vpt)
228+
227229
mdIndexName := viper.GetString("md_index_name")
228230
if mdIndexName == "" {
229231
log.Fatal("Must specify a name for the elastic index.")

src/cloud/api/ptproxy/BUILD.bazel

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,15 @@ go_library(
2828
deps = [
2929
"//src/api/proto/uuidpb:uuid_pl_go_proto",
3030
"//src/api/proto/vizierpb:vizier_pl_go_proto",
31+
"//src/cloud/scriptmgr/scriptmgrpb:service_pl_go_proto",
3132
"//src/cloud/shared/vzshard",
3233
"//src/shared/cvmsgspb:cvmsgs_pl_go_proto",
3334
"//src/shared/services/authcontext",
35+
"//src/shared/services/utils",
3436
"//src/shared/services/jwtpb:jwt_pl_go_proto",
3537
"//src/utils",
3638
"@com_github_gofrs_uuid//:uuid",
39+
"@com_github_spf13_viper//:viper",
3740
"@com_github_gogo_protobuf//proto",
3841
"@com_github_gogo_protobuf//types",
3942
"@com_github_nats_io_nats_go//:nats_go",
@@ -53,6 +56,7 @@ pl_go_test(
5356
":ptproxy",
5457
"//src/api/proto/uuidpb:uuid_pl_go_proto",
5558
"//src/api/proto/vizierpb:vizier_pl_go_proto",
59+
"//src/cloud/scriptmgr/scriptmgrpb:service_pl_go_proto",
5660
"//src/cloud/shared/vzshard",
5761
"//src/shared/cvmsgspb:cvmsgs_pl_go_proto",
5862
"//src/shared/services/env",

src/cloud/api/ptproxy/vizier_pt_proxy.go

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,21 @@ package ptproxy
2020

2121
import (
2222
"context"
23+
"crypto/sha256"
24+
"encoding/hex"
25+
"fmt"
2326

2427
"github.com/nats-io/nats.go"
28+
"github.com/spf13/viper"
2529
"google.golang.org/grpc"
30+
"google.golang.org/grpc/codes"
31+
"google.golang.org/grpc/metadata"
32+
"google.golang.org/grpc/status"
33+
jwtutils "px.dev/pixie/src/shared/services/utils"
2634

2735
"px.dev/pixie/src/api/proto/uuidpb"
2836
"px.dev/pixie/src/api/proto/vizierpb"
37+
"px.dev/pixie/src/cloud/scriptmgr/scriptmgrpb"
2938
"px.dev/pixie/src/shared/cvmsgspb"
3039
"px.dev/pixie/src/shared/services/authcontext"
3140
"px.dev/pixie/src/shared/services/jwtpb"
@@ -36,16 +45,46 @@ type vzmgrClient interface {
3645
GetVizierConnectionInfo(ctx context.Context, in *uuidpb.UUID, opts ...grpc.CallOption) (*cvmsgspb.VizierConnectionInfo, error)
3746
}
3847

48+
type scriptmgrClient interface {
49+
GetScriptByHash(ctx context.Context, req *scriptmgrpb.GetScriptByHashReq, opts ...grpc.CallOption) (*scriptmgrpb.GetScriptByHashResp, error)
50+
}
51+
3952
// VizierPassThroughProxy implements the VizierAPI and allows proxying the data to the actual
4053
// vizier cluster.
4154
type VizierPassThroughProxy struct {
42-
nc *nats.Conn
43-
vc vzmgrClient
55+
nc *nats.Conn
56+
vc vzmgrClient
57+
sm scriptmgrClient
58+
disableScriptModifiation bool
59+
}
60+
61+
// getServiceCredentials returns JWT credentials for inter-service requests.
62+
func getServiceCredentials(signingKey string) (string, error) {
63+
claims := jwtutils.GenerateJWTForService("cloud api", viper.GetString("domain_name"))
64+
return jwtutils.SignJWTClaims(claims, signingKey)
4465
}
4566

4667
// NewVizierPassThroughProxy creates a new passthrough proxy.
47-
func NewVizierPassThroughProxy(nc *nats.Conn, vc vzmgrClient) *VizierPassThroughProxy {
48-
return &VizierPassThroughProxy{nc: nc, vc: vc}
68+
func NewVizierPassThroughProxy(nc *nats.Conn, vc vzmgrClient, sm scriptmgrClient, disableScriptModifiation bool) *VizierPassThroughProxy {
69+
return &VizierPassThroughProxy{nc: nc, vc: vc, sm: sm, disableScriptModifiation: disableScriptModifiation}
70+
}
71+
72+
func (v *VizierPassThroughProxy) isScriptModified(ctx context.Context, script string) (bool, error) {
73+
hash := sha256.New()
74+
hash.Write([]byte(script))
75+
hashStr := hex.EncodeToString(hash.Sum(nil))
76+
req := &scriptmgrpb.GetScriptByHashReq{Sha256Hash: hashStr}
77+
78+
serviceAuthToken, err := getServiceCredentials(viper.GetString("jwt_signing_key"))
79+
ctx = metadata.AppendToOutgoingContext(ctx, "authorization",
80+
fmt.Sprintf("bearer %s", serviceAuthToken))
81+
82+
resp, err := v.sm.GetScriptByHash(ctx, req)
83+
84+
if err != nil {
85+
return false, err
86+
}
87+
return !resp.Exists, nil
4988
}
5089

5190
// ExecuteScript is the GRPC stream method.
@@ -55,6 +94,17 @@ func (v *VizierPassThroughProxy) ExecuteScript(req *vizierpb.ExecuteScriptReques
5594
return err
5695
}
5796
defer rp.Finish()
97+
if v.disableScriptModifiation {
98+
modified, err := v.isScriptModified(srv.Context(), req.QueryStr)
99+
if err != nil {
100+
return err
101+
}
102+
103+
if modified {
104+
return status.Error(codes.InvalidArgument, "Script modification has been disabled")
105+
}
106+
}
107+
58108
vizReq := rp.prepareVizierRequest()
59109
vizReq.Msg = &cvmsgspb.C2VAPIStreamRequest_ExecReq{ExecReq: req}
60110
if err := rp.sendMessageToVizier(vizReq); err != nil {

src/cloud/api/ptproxy/vizier_pt_proxy_test.go

Lines changed: 153 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ import (
4444
"px.dev/pixie/src/api/proto/uuidpb"
4545
"px.dev/pixie/src/api/proto/vizierpb"
4646
"px.dev/pixie/src/cloud/api/ptproxy"
47+
"px.dev/pixie/src/cloud/scriptmgr/scriptmgrpb"
4748
"px.dev/pixie/src/cloud/shared/vzshard"
4849
"px.dev/pixie/src/shared/cvmsgspb"
4950
"px.dev/pixie/src/shared/services/env"
@@ -65,15 +66,15 @@ type testState struct {
6566
conn *grpc.ClientConn
6667
}
6768

68-
func createTestState(t *testing.T) (*testState, func(t *testing.T)) {
69+
func createTestState(t *testing.T, disableScriptModification bool) (*testState, func(t *testing.T)) {
6970
lis := bufconn.Listen(bufSize)
7071
env := env.New("withpixie.ai")
7172
s := server.CreateGRPCServer(env, &server.GRPCServerOptions{})
7273

7374
nc, natsCleanup := testingutils.MustStartTestNATS(t)
7475

75-
vizierpb.RegisterVizierServiceServer(s, ptproxy.NewVizierPassThroughProxy(nc, &fakeVzMgr{}))
76-
vizierpb.RegisterVizierDebugServiceServer(s, ptproxy.NewVizierPassThroughProxy(nc, &fakeVzMgr{}))
76+
vizierpb.RegisterVizierServiceServer(s, ptproxy.NewVizierPassThroughProxy(nc, &fakeVzMgr{}, &fakeScriptMgr{}, disableScriptModification))
77+
vizierpb.RegisterVizierDebugServiceServer(s, ptproxy.NewVizierPassThroughProxy(nc, &fakeVzMgr{}, &fakeScriptMgr{}, disableScriptModification))
7778

7879
eg := errgroup.Group{}
7980
eg.Go(func() error { return s.Serve(lis) })
@@ -112,7 +113,7 @@ func createDialer(lis *bufconn.Listener) func(ctx context.Context, url string) (
112113
func TestVizierPassThroughProxy_ExecuteScript(t *testing.T) {
113114
viper.Set("jwt_signing_key", "the-key")
114115

115-
ts, cleanup := createTestState(t)
116+
ts, cleanup := createTestState(t, false)
116117
defer cleanup(t)
117118

118119
client := vizierpb.NewVizierServiceClient(ts.conn)
@@ -283,10 +284,141 @@ func TestVizierPassThroughProxy_ExecuteScript(t *testing.T) {
283284
}
284285
}
285286

287+
func TestVizierPassThroughProxy_ExecuteScriptWithScriptModificationEnabled(t *testing.T) {
288+
viper.Set("jwt_signing_key", "the-key")
289+
290+
ts, cleanup := createTestState(t, true)
291+
defer cleanup(t)
292+
293+
client := vizierpb.NewVizierServiceClient(ts.conn)
294+
validTestToken := testingutils.GenerateTestJWTToken(t, viper.GetString("jwt_signing_key"))
295+
296+
testCases := []struct {
297+
name string
298+
299+
clusterID string
300+
authToken string
301+
pxlString string
302+
respFromVizier []*cvmsgspb.V2CAPIStreamResponse
303+
304+
expGRPCError error
305+
expGRPCResponses []*vizierpb.ExecuteScriptResponse
306+
}{
307+
{
308+
name: "Request with modified pxl script",
309+
310+
clusterID: "00000000-1111-2222-2222-333333333333",
311+
pxlString: "import pxl",
312+
authToken: validTestToken,
313+
expGRPCError: status.Error(codes.InvalidArgument, "Script modification has been disabled"),
314+
expGRPCResponses: nil,
315+
},
316+
{
317+
name: "Request with not modified pxl script",
318+
319+
clusterID: "00000000-1111-2222-2222-333333333333",
320+
pxlString: "liveview1 pxl",
321+
authToken: validTestToken,
322+
expGRPCError: nil,
323+
expGRPCResponses: []*vizierpb.ExecuteScriptResponse{
324+
{
325+
QueryID: "abc",
326+
},
327+
{
328+
QueryID: "abc",
329+
},
330+
},
331+
respFromVizier: []*cvmsgspb.V2CAPIStreamResponse{
332+
{
333+
Msg: &cvmsgspb.V2CAPIStreamResponse_ExecResp{ExecResp: &vizierpb.ExecuteScriptResponse{QueryID: "abc"}},
334+
},
335+
{
336+
Msg: &cvmsgspb.V2CAPIStreamResponse_ExecResp{ExecResp: &vizierpb.ExecuteScriptResponse{QueryID: "abc"}},
337+
},
338+
},
339+
},
340+
}
341+
342+
for _, tc := range testCases {
343+
t.Run(tc.name, func(t *testing.T) {
344+
ctx := context.Background()
345+
if len(tc.authToken) > 0 {
346+
ctx = metadata.AppendToOutgoingContext(ctx, "authorization",
347+
fmt.Sprintf("bearer %s", tc.authToken))
348+
}
349+
350+
ctx, cancel := context.WithCancel(ctx)
351+
defer cancel()
352+
resp, err := client.ExecuteScript(ctx,
353+
&vizierpb.ExecuteScriptRequest{ClusterID: tc.clusterID, QueryStr: tc.pxlString})
354+
require.NoError(t, err)
355+
fv := newFakeVizier(t, uuid.FromStringOrNil(tc.clusterID), ts.nc)
356+
fv.Run(t, tc.respFromVizier)
357+
defer fv.Stop()
358+
359+
grpcDataCh := make(chan *vizierpb.ExecuteScriptResponse)
360+
var gotReadErr error
361+
var eg errgroup.Group
362+
eg.Go(func() error {
363+
defer close(grpcDataCh)
364+
for {
365+
d, err := resp.Recv()
366+
if err != nil && err != io.EOF {
367+
gotReadErr = err
368+
}
369+
if err == io.EOF {
370+
return nil
371+
}
372+
if d == nil {
373+
return nil
374+
}
375+
grpcDataCh <- d
376+
}
377+
})
378+
379+
var responses []*vizierpb.ExecuteScriptResponse
380+
eg.Go(func() error {
381+
timeout := time.NewTimer(defaultTimeout)
382+
defer timeout.Stop()
383+
384+
for {
385+
select {
386+
case <-resp.Context().Done():
387+
return nil
388+
case <-timeout.C:
389+
return fmt.Errorf("timeout waiting for data on grpc channel")
390+
case msg := <-grpcDataCh:
391+
392+
if msg == nil {
393+
return nil
394+
}
395+
responses = append(responses, msg)
396+
}
397+
}
398+
})
399+
400+
err = eg.Wait()
401+
if tc.expGRPCError != nil {
402+
if gotReadErr == nil {
403+
t.Fatal("Expected to get GRPC error")
404+
}
405+
assert.Equal(t, status.Code(tc.expGRPCError), status.Code(gotReadErr))
406+
}
407+
if tc.expGRPCResponses == nil {
408+
if len(responses) != 0 {
409+
t.Fatal("Expected to get no responses")
410+
}
411+
} else {
412+
assert.Equal(t, tc.expGRPCResponses, responses)
413+
}
414+
})
415+
}
416+
}
417+
286418
func TestVizierPassThroughProxy_HealthCheck(t *testing.T) {
287419
viper.Set("jwt_signing_key", "the-key")
288420

289-
ts, cleanup := createTestState(t)
421+
ts, cleanup := createTestState(t, false)
290422
defer cleanup(t)
291423

292424
client := vizierpb.NewVizierServiceClient(ts.conn)
@@ -463,7 +595,7 @@ func TestVizierPassThroughProxy_HealthCheck(t *testing.T) {
463595
func TestVizierPassThroughProxy_DebugLog(t *testing.T) {
464596
viper.Set("jwt_signing_key", "the-key")
465597

466-
ts, cleanup := createTestState(t)
598+
ts, cleanup := createTestState(t, false)
467599
defer cleanup(t)
468600

469601
client := vizierpb.NewVizierDebugServiceClient(ts.conn)
@@ -582,7 +714,7 @@ func TestVizierPassThroughProxy_DebugLog(t *testing.T) {
582714
func TestVizierPassThroughProxy_DebugPods(t *testing.T) {
583715
viper.Set("jwt_signing_key", "the-key")
584716

585-
ts, cleanup := createTestState(t)
717+
ts, cleanup := createTestState(t, false)
586718
defer cleanup(t)
587719

588720
client := vizierpb.NewVizierDebugServiceClient(ts.conn)
@@ -719,6 +851,20 @@ func TestVizierPassThroughProxy_DebugPods(t *testing.T) {
719851
}
720852
}
721853

854+
type fakeScriptMgr struct{}
855+
856+
func (s *fakeScriptMgr) GetScriptByHash(ctx context.Context, req *scriptmgrpb.GetScriptByHashReq, opts ...grpc.CallOption) (*scriptmgrpb.GetScriptByHashResp, error) {
857+
hash := "488f131003f415a61090901c544e0ace731e8a85b12ce0aea770273d656f08e0" // sha256 of "liveview1 pxl"
858+
859+
scripts := map[string]bool{
860+
hash: true,
861+
}
862+
_, ok := scripts[req.Sha256Hash]
863+
return &scriptmgrpb.GetScriptByHashResp{
864+
Exists: ok,
865+
}, nil
866+
}
867+
722868
type fakeVzMgr struct{}
723869

724870
func (v *fakeVzMgr) GetVizierInfo(ctx context.Context, in *uuidpb.UUID, opts ...grpc.CallOption) (*cvmsgspb.VizierInfo, error) {

0 commit comments

Comments
 (0)