diff --git a/edgraph/access_ee.go b/edgraph/access_ee.go index ccb3b6b07fc..0ca838cd4dd 100644 --- a/edgraph/access_ee.go +++ b/edgraph/access_ee.go @@ -40,6 +40,11 @@ import ( "google.golang.org/grpc/status" ) +type predsAndvars struct { + preds []string + vars map[string]string +} + // Login handles login requests from clients. func (s *Server) Login(ctx context.Context, request *api.LoginRequest) (*api.Response, error) { @@ -683,14 +688,19 @@ func authorizeMutation(ctx context.Context, gmu *gql.Mutation) error { return err } -func parsePredsFromQuery(gqls []*gql.GraphQuery) []string { +func parsePredsFromQuery(gqls []*gql.GraphQuery) predsAndvars { predsMap := make(map[string]struct{}) + varsMap := make(map[string]string) for _, gq := range gqls { if gq.Func != nil { predsMap[gq.Func.Attr] = struct{}{} } - if len(gq.Attr) > 0 && gq.Attr != "uid" && gq.Attr != "expand" { + if len(gq.Var) > 0 { + varsMap[gq.Var] = gq.Attr + } + if len(gq.Attr) > 0 && gq.Attr != "uid" && gq.Attr != "expand" && gq.Attr != "val" { predsMap[gq.Attr] = struct{}{} + } for _, ord := range gq.Order { predsMap[ord.Attr] = struct{}{} @@ -701,15 +711,23 @@ func parsePredsFromQuery(gqls []*gql.GraphQuery) []string { for _, pred := range parsePredsFromFilter(gq.Filter) { predsMap[pred] = struct{}{} } - for _, childPred := range parsePredsFromQuery(gq.Children) { + childPredandVars := parsePredsFromQuery(gq.Children) + for _, childPred := range childPredandVars.preds { predsMap[childPred] = struct{}{} } + for childVar := range childPredandVars.vars { + varsMap[childVar] = childPredandVars.vars[childVar] + } } preds := make([]string, 0, len(predsMap)) for pred := range predsMap { - preds = append(preds, pred) + if _, found := varsMap[pred]; !found { + preds = append(preds, pred) + } } - return preds + + pv := predsAndvars{preds: preds, vars: varsMap} + return pv } func parsePredsFromFilter(f *gql.FilterTree) []string { @@ -756,7 +774,16 @@ func authorizeQuery(ctx context.Context, parsedReq *gql.Result, graphql bool) er var userId string var groupIds []string - preds := parsePredsFromQuery(parsedReq.Query) + predsAndvars := parsePredsFromQuery(parsedReq.Query) + preds := predsAndvars.preds + varsToPredMap := predsAndvars.vars + + // Need this to efficiently identify blocked variables from the + // list of blocked predicates + predToVarsMap := make(map[string]string) + for k, v := range varsToPredMap { + predToVarsMap[v] = k + } doAuthorizeQuery := func() (map[string]struct{}, []string, error) { userData, err := extractUserAndGroups(ctx) @@ -807,7 +834,18 @@ func authorizeQuery(ctx context.Context, parsedReq *gql.Result, graphql bool) er // In query context ~predicate and predicate are considered different. delete(blockedPreds, "~dgraph.user.group") } + + blockedVars := make(map[string]struct{}) + for predicate := range blockedPreds { + if variable, found := predToVarsMap[predicate]; found { + // Add variables to blockedPreds to delete from Query + blockedPreds[variable] = struct{}{} + // Collect blocked Variables to remove from QueryVars + blockedVars[variable] = struct{}{} + } + } parsedReq.Query = removePredsFromQuery(parsedReq.Query, blockedPreds) + parsedReq.QueryVars = removeVarsFromQueryVars(parsedReq.QueryVars, blockedVars) } for i := range parsedReq.Query { parsedReq.Query[i].AllowedPreds = allowedPreds @@ -1058,6 +1096,7 @@ func removePredsFromQuery(gqs []*gql.GraphQuery, blockedPreds map[string]struct{}) []*gql.GraphQuery { filteredGQs := gqs[:0] +L: for _, gq := range gqs { if gq.Func != nil && len(gq.Func.Attr) > 0 { if _, ok := blockedPreds[gq.Func.Attr]; ok { @@ -1068,6 +1107,15 @@ func removePredsFromQuery(gqs []*gql.GraphQuery, if _, ok := blockedPreds[gq.Attr]; ok { continue } + if gq.Attr == "val" { + // TODO (Anurag): If val supports multiple variables, this would + // need an upgrade + for _, variable := range gq.NeedsVar { + if _, ok := blockedPreds[variable.Name]; ok { + continue L + } + } + } } order := gq.Order[:0] @@ -1088,6 +1136,30 @@ func removePredsFromQuery(gqs []*gql.GraphQuery, return filteredGQs } +func removeVarsFromQueryVars(gqs []*gql.Vars, + blockedVars map[string]struct{}) []*gql.Vars { + + filteredGQs := gqs[:0] + for _, gq := range gqs { + var defines []string + var needs []string + for _, variable := range gq.Defines { + if _, ok := blockedVars[variable]; !ok { + defines = append(defines, variable) + } + } + for _, variable := range gq.Needs { + if _, ok := blockedVars[variable]; !ok { + needs = append(needs, variable) + } + } + gq.Defines = defines + gq.Needs = needs + filteredGQs = append(filteredGQs, gq) + } + return filteredGQs +} + func removeFilters(f *gql.FilterTree, blockedPreds map[string]struct{}) *gql.FilterTree { if f == nil { return nil diff --git a/ee/acl/acl_test.go b/ee/acl/acl_test.go index edbe63de85f..ac1c4a1235d 100644 --- a/ee/acl/acl_test.go +++ b/ee/acl/acl_test.go @@ -1255,6 +1255,209 @@ func TestExpandQueryWithACLPermissions(t *testing.T) { testutil.CompareJSON(t, `{"me":[{"name":"RandomGuy","age":23, "nickname":"RG"},{"name":"RandomGuy2","age":25, "nickname":"RG2"}]}`, string(resp.GetJson())) +} + +func TestValQueryWithACLPermissions(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Second) + defer cancel() + dg, err := testutil.DgraphClientWithGroot(testutil.SockAddr) + require.NoError(t, err) + + testutil.DropAll(t, dg) + + op := api.Operation{Schema: ` + name : string @index(exact) . + nickname : string @index(exact) . + age : int . + type TypeName { + name: string + nickname: string + age: int + } + `} + require.NoError(t, dg.Alter(ctx, &op)) + + resetUser(t) + + accessJwt, _, err := testutil.HttpLogin(&testutil.LoginParams{ + Endpoint: adminEndpoint, + UserID: "groot", + Passwd: "password", + }) + require.NoError(t, err, "login failed") + + createGroup(t, accessJwt, devGroup) + // createGroup(t, accessJwt, sreGroup) + + // addRulesToGroup(t, accessJwt, sreGroup, []rule{{"age", Read.Code}, {"name", Write.Code}}) + addToGroup(t, accessJwt, userid, devGroup) + + txn := dg.NewTxn() + mutation := &api.Mutation{ + SetNquads: []byte(` + _:a "RandomGuy" . + _:a "23" . + _:a "RG" . + _:a "TypeName" . + _:b "RandomGuy2" . + _:b "25" . + _:b "RG2" . + _:b "TypeName" . + `), + CommitNow: true, + } + _, err = txn.Mutate(ctx, mutation) + require.NoError(t, err) + + query := `{q1(func: has(name)){ + v as name + a as age + } + q2(func: eq(val(v), "RandomGuy")) { + val(v) + val(a) + }}` + + // Test that groot has access to all the predicates + resp, err := dg.NewReadOnlyTxn().Query(ctx, query) + require.NoError(t, err, "Error while querying data") + testutil.CompareJSON(t, `{"q1":[{"name":"RandomGuy","age":23},{"name":"RandomGuy2","age":25}],"q2":[{"val(v)":"RandomGuy","val(a)":23}]}`, + string(resp.GetJson())) + + // All test cases + tests := []struct { + input string + descriptionNoPerm string + outputNoPerm string + descriptionNamePerm string + outputNamePerm string + descriptionNameAgePerm string + outputNameAgePerm string + }{ + { + ` + { + q1(func: has(name)) { + v as name + a as age + } + q2(func: eq(val(v), "RandomGuy")) { + val(v) + val(a) + } + } + `, + "alice doesn't have access to name or age", + `{}`, + + `alice has access to name`, + `{"q1":[{"name":"RandomGuy"},{"name":"RandomGuy2"}],"q2":[{"val(v)":"RandomGuy"}]}`, + + "alice has access to name and age", + `{"q1":[{"name":"RandomGuy","age":23},{"name":"RandomGuy2","age":25}],"q2":[{"val(v)":"RandomGuy","val(a)":23}]}`, + }, + { + `{ + q1(func: has(name) ) { + a as age + } + q2(func: has(name) ) { + val(a) + } + }`, + "alice doesn't have access to name or age", + `{}`, + + `alice has access to name`, + `{"q1":[],"q2":[]}`, + + "alice has access to name and age", + `{"q1":[{"age":23},{"age":25}],"q2":[{"val(a)":23},{"val(a)":25}]}`, + }, + { + `{ + f as q1(func: has(name) ) { + n as name + a as age + } + q2(func: uid(f), orderdesc: val(a) ) { + name + val(n) + val(a) + } + }`, + "alice doesn't have access to name or age", + `{"q2":[]}`, + + `alice has access to name`, + `{"q1":[{"name":"RandomGuy"},{"name":"RandomGuy2"}], + "q2":[{"name":"RandomGuy","val(n)":"RandomGuy"},{"name":"RandomGuy2","val(n)":"RandomGuy2"}]}`, + + "alice has access to name and age", + `{"q1":[{"name":"RandomGuy","age":23},{"name":"RandomGuy2","age":25}], + "q2":[{"name":"RandomGuy2","val(n)":"RandomGuy2","val(a)":25},{"name":"RandomGuy","val(n)":"RandomGuy","val(a)":23}]}`, + }, + } + + userClient, err := testutil.DgraphClient(testutil.SockAddr) + require.NoError(t, err) + time.Sleep(6 * time.Second) + + err = userClient.Login(ctx, userid, userpassword) + require.NoError(t, err) + + // Query via user when user has no permissions + for _, tc := range tests { + desc := tc.descriptionNoPerm + t.Run(desc, func(t *testing.T) { + resp, err := userClient.NewTxn().Query(ctx, tc.input) + require.NoError(t, err) + testutil.CompareJSON(t, tc.outputNoPerm, string(resp.Json)) + }) + } + + // Login to groot to modify accesses (1) + accessJwt, _, err = testutil.HttpLogin(&testutil.LoginParams{ + Endpoint: adminEndpoint, + UserID: "groot", + Passwd: "password", + }) + require.NoError(t, err, "login failed") + + // Give read access of to dev + addRulesToGroup(t, accessJwt, devGroup, []rule{{"name", Read.Code}}) + time.Sleep(6 * time.Second) + + for _, tc := range tests { + desc := tc.descriptionNamePerm + t.Run(desc, func(t *testing.T) { + resp, err := userClient.NewTxn().Query(ctx, tc.input) + require.NoError(t, err) + testutil.CompareJSON(t, tc.outputNamePerm, string(resp.Json)) + }) + } + + // Login to groot to modify accesses (1) + accessJwt, _, err = testutil.HttpLogin(&testutil.LoginParams{ + Endpoint: adminEndpoint, + UserID: "groot", + Passwd: "password", + }) + require.NoError(t, err, "login failed") + + // Give read access of and to dev + addRulesToGroup(t, accessJwt, devGroup, []rule{{"name", Read.Code}, {"age", Read.Code}}) + time.Sleep(6 * time.Second) + + for _, tc := range tests { + desc := tc.descriptionNameAgePerm + t.Run(desc, func(t *testing.T) { + resp, err := userClient.NewTxn().Query(ctx, tc.input) + require.NoError(t, err) + testutil.CompareJSON(t, tc.outputNameAgePerm, string(resp.Json)) + }) + } + } func TestNewACLPredicates(t *testing.T) { ctx, _ := context.WithTimeout(context.Background(), 100*time.Second)