Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 43 additions & 4 deletions internal/check/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/ory/keto/internal/driver/config"
"github.com/ory/keto/internal/namespace"
"github.com/ory/keto/internal/namespace/ast"
"github.com/ory/keto/internal/persistence"
"github.com/ory/keto/internal/relationtuple"
"github.com/ory/keto/internal/x"
"github.com/ory/keto/internal/x/graph"
Expand All @@ -31,6 +32,7 @@ type (
relationtuple.ManagerProvider
config.Provider
x.LoggerProvider
Persister() persistence.Persister
}

EngineOpt func(*Engine)
Expand Down Expand Up @@ -159,11 +161,32 @@ func (e *Engine) checkDirect(r *relationTuple, restDepth int) checkgroup.CheckFu
e.d.Logger().
WithField("request", r.String()).
Trace("check direct")
q := r.ToQuery()
if w := argsFromCtx(ctx); w != nil && len(w.Args) == 1 {
// fill subject with the passed argument
if v, ok := w.Args[0].(ast.StringLiteralArg); ok {
u, err := e.d.Persister().MapStringsToUUIDs(ctx, v.Value(ctx))
if err != nil {
resultCh <- checkgroup.Result{
Membership: checkgroup.NotMember,
}
e.d.Logger().WithField("request", r.String()).Error("failed to check direct", err)
return
}
q.Subject = &relationtuple.SubjectID{ID: u[0]}
}
}
if rels, _, err := e.d.RelationTupleManager().GetRelationTuples(
ctx,
r.ToQuery(),
q,
x.WithSize(1),
); err == nil && len(rels) > 0 {
if q.Subject != r.Subject {
// fix the Tree
t := *r
t.Subject = q.Subject
r = &t
}
resultCh <- checkgroup.Result{
Membership: checkgroup.IsMember,
Tree: &ketoapi.Tree[*relationtuple.RelationTuple]{
Expand Down Expand Up @@ -195,11 +218,26 @@ func (e *Engine) checkIsAllowed(ctx context.Context, r *relationTuple, restDepth
WithField("request", r.String()).
Trace("check is allowed")

relation, err := e.astRelationFor(ctx, r)
if w := argsFromCtx(ctx); w != nil && len(w.Args) > 0 && len(relation.Params) > 0 {
// map arg-name to value
for i, p := range relation.Params {
if p != "ctx" {
if w.Mapping == nil {
w.Mapping = make(map[string]ast.Arg)
}
w.Mapping[p] = w.Args[i]
}
}
}

g := checkgroup.New(ctx)
g.Add(e.checkDirect(r, restDepth-1))
g.Add(e.checkExpandSubject(r, restDepth))
// do not make checks for faked helper relations
if relation == nil || len(relation.Params) == 1 {
g.Add(e.checkDirect(r, restDepth-1))
g.Add(e.checkExpandSubject(r, restDepth))
}

relation, err := e.astRelationFor(ctx, r)
if err != nil {
g.Add(checkgroup.ErrorFunc(err))
} else if relation != nil && relation.SubjectSetRewrite != nil {
Expand Down Expand Up @@ -231,6 +269,7 @@ func (e *Engine) astRelationFor(ctx context.Context, r *relationTuple) (*ast.Rel

for _, rel := range ns.Relations {
if rel.Name == r.Relation {
r.Formula = &rel
return &rel, nil
}
}
Expand Down
17 changes: 17 additions & 0 deletions internal/check/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/ory/keto/internal/driver"
"github.com/ory/keto/internal/driver/config"
"github.com/ory/keto/internal/namespace"
"github.com/ory/keto/internal/persistence"
"github.com/ory/keto/internal/relationtuple"
"github.com/ory/keto/internal/x"
"github.com/ory/keto/ketoapi"
Expand Down Expand Up @@ -42,6 +43,22 @@ func newDepsProvider(t testing.TB, namespaces []*namespace.Namespace, pageOpts .
}
}

func (d *deps) Persister() persistence.Persister {
return persister{}
}

type persister struct {
persistence.Persister
}

func (p persister) MapStringsToUUIDs(ctx context.Context, s ...string) ([]uuid.UUID, error) {
u := make([]uuid.UUID, len(s))
for i, v := range s {
u[i] = toUUID(v)
}
return u, nil
}

func toUUID(s string) uuid.UUID {
return uuid.NewV5(uuid.Nil, s)
}
Expand Down
33 changes: 33 additions & 0 deletions internal/check/rewrites.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ func (e *Engine) checkComputedSubjectSet(
WithField("computed subjectSet relation", subjectSet.Relation).
Trace("check computed subjectSet")

ctx = wrapArgs(ctx, subjectSet.Args)
return e.checkIsAllowed(
ctx,
&relationTuple{
Expand Down Expand Up @@ -227,6 +228,7 @@ func (e *Engine) checkTupleToSubjectSet(
tuples []*relationTuple
err error
)
ctx = wrapArgs(ctx, subjectSet.Args)
g := checkgroup.New(ctx)
for nextPage = "x"; nextPage != "" && !g.Done(); prevPage = nextPage {
tuples, nextPage, err = e.d.RelationTupleManager().GetRelationTuples(
Expand Down Expand Up @@ -261,3 +263,34 @@ func (e *Engine) checkTupleToSubjectSet(
resultCh <- g.Result()
}
}

type argsCtxKey struct{}

func newCtxWithArgs(ctx context.Context, args *ArgsWrapper) context.Context {
return context.WithValue(ctx, argsCtxKey{}, args)
}

func argsFromCtx(ctx context.Context) *ArgsWrapper {
args, _ := ctx.Value(argsCtxKey{}).(*ArgsWrapper)
return args
}

type ArgsWrapper struct {
Args []ast.Arg
Mapping map[string]ast.Arg
}

func wrapArgs(ctx context.Context, args []ast.Arg) context.Context {
w := argsFromCtx(ctx)
if w != nil {
// replace named-args with real values
a := append([]ast.Arg{}, args...)
for i, p := range a {
if _, ok := p.(ast.NamedArg); ok {
a[i] = w.Mapping[p.Value(ctx)]
}
}
args = a
}
return newCtxWithArgs(ctx, &ArgsWrapper{Args: args})
}
Loading