Skip to content

Commit

Permalink
feat: Support batch enforce and add corresponding test.
Browse files Browse the repository at this point in the history
  • Loading branch information
AsakusaRinne committed Aug 29, 2022
1 parent aa73b71 commit c107068
Show file tree
Hide file tree
Showing 5 changed files with 376 additions and 1 deletion.
168 changes: 168 additions & 0 deletions Casbin.UnitTests/ModelTests/EnforcerTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,74 @@ public void TestRbacModelInMemory()
TestEnforce(e, "bob", "data2", "write", true);
}

[Fact]
public void TestRbacBatchEnforceInMemory()
{
IModel m = DefaultModel.Create();
m.AddDef("r", "r", "sub, obj, act");
m.AddDef("p", "p", "sub, obj, act");
m.AddDef("g", "g", "_, _");
m.AddDef("e", "e", "some(where (p.eft == allow))");
m.AddDef("m", "m", "g(r.sub, p.sub) && r.obj == p.obj && r.act == p.act");

Enforcer e = new(m);

e.AddPermissionForUser("alice", "data1", "read");
e.AddPermissionForUser("bob", "data2", "write");
e.AddPermissionForUser("data2_admin", "data2", "read");
e.AddPermissionForUser("data2_admin", "data2", "write");
e.AddRoleForUser("alice", "data2_admin");

IEnumerable<(RequestValues<string, string, string>, bool)> testCases =
new (RequestValues<string, string, string>, bool)[]
{
(Request.CreateValues("alice", "data1", "read"), true),
(Request.CreateValues("alice", "data1", "write"), false),
(Request.CreateValues("alice", "data2", "read"), true),
(Request.CreateValues("alice", "data2", "write"), true),
(Request.CreateValues("bob", "data1", "read"), false),
(Request.CreateValues("bob", "data1", "write"), false),
(Request.CreateValues("bob", "data2", "read"), false),
(Request.CreateValues("bob", "data2", "write"), true)
};

TestBatchEnforce(e, testCases);
}

[Fact]
public void TestRbacBatchEnforceParallelInMemory()
{
IModel m = DefaultModel.Create();
m.AddDef("r", "r", "sub, obj, act");
m.AddDef("p", "p", "sub, obj, act");
m.AddDef("g", "g", "_, _");
m.AddDef("e", "e", "some(where (p.eft == allow))");
m.AddDef("m", "m", "g(r.sub, p.sub) && r.obj == p.obj && r.act == p.act");

Enforcer e = new(m);

e.AddPermissionForUser("alice", "data1", "read");
e.AddPermissionForUser("bob", "data2", "write");
e.AddPermissionForUser("data2_admin", "data2", "read");
e.AddPermissionForUser("data2_admin", "data2", "write");
e.AddRoleForUser("alice", "data2_admin");

IEnumerable<(RequestValues<string, string, string>, bool)> testCases =
new (RequestValues<string, string, string>, bool)[]
{
(Request.CreateValues("alice", "data1", "read"), true),
(Request.CreateValues("alice", "data1", "write"), false),
(Request.CreateValues("alice", "data2", "read"), true),
(Request.CreateValues("alice", "data2", "write"), true),
(Request.CreateValues("bob", "data1", "read"), false),
(Request.CreateValues("bob", "data1", "write"), false),
(Request.CreateValues("bob", "data2", "read"), false),
(Request.CreateValues("bob", "data2", "write"), true)
};

TestBatchEnforceParallel(e, testCases);
}

[Fact]
public async Task TestRbacModelInMemoryAsync()
{
Expand Down Expand Up @@ -338,6 +406,40 @@ public async Task TestRbacModelInMemoryAsync()
await TestEnforceAsync(e, "bob", "data2", "write", true);
}

[Fact]
public void TestRbacBatchEnforceInMemoryAsync()
{
IModel m = DefaultModel.Create();
m.AddDef("r", "r", "sub, obj, act");
m.AddDef("p", "p", "sub, obj, act");
m.AddDef("g", "g", "_, _");
m.AddDef("e", "e", "some(where (p.eft == allow))");
m.AddDef("m", "m", "g(r.sub, p.sub) && r.obj == p.obj && r.act == p.act");

Enforcer e = new(m);

e.AddPermissionForUserAsync("alice", "data1", "read");
e.AddPermissionForUserAsync("bob", "data2", "write");
e.AddPermissionForUserAsync("data2_admin", "data2", "read");
e.AddPermissionForUserAsync("data2_admin", "data2", "write");
e.AddRoleForUserAsync("alice", "data2_admin");

IEnumerable<(RequestValues<string, string, string>, bool)> testCases =
new (RequestValues<string, string, string>, bool)[]
{
(Request.CreateValues("alice", "data1", "read"), true),
(Request.CreateValues("alice", "data1", "write"), false),
(Request.CreateValues("alice", "data2", "read"), true),
(Request.CreateValues("alice", "data2", "write"), true),
(Request.CreateValues("bob", "data1", "read"), false),
(Request.CreateValues("bob", "data1", "write"), false),
(Request.CreateValues("bob", "data2", "read"), false),
(Request.CreateValues("bob", "data2", "write"), true)
};

TestBatchEnforceAsync(e, testCases);
}

[Fact]
public void TestRbacModelInMemory2()
{
Expand Down Expand Up @@ -1030,6 +1132,50 @@ public void TestEnforceWithMatcherApi()
e.TestEnforceWithMatcher(matcher, "bob", "data2", "write", false);
}

[Fact]
public void TestBatchEnforceWithMatcherApi()
{
Enforcer e = new(_testModelFixture.GetBasicTestModel());
string matcher = "r.sub != p.sub && r.obj == p.obj && r.act == p.act";

IEnumerable<(RequestValues<string, string, string>, bool)> testCases =
new (RequestValues<string, string, string>, bool)[]
{
(Request.CreateValues("alice", "data1", "read"), false),
(Request.CreateValues("alice", "data1", "write"), false),
(Request.CreateValues("alice", "data2", "read"), false),
(Request.CreateValues("alice", "data2", "write"), true),
(Request.CreateValues("bob", "data1", "read"), true),
(Request.CreateValues("bob", "data1", "write"), false),
(Request.CreateValues("bob", "data2", "read"), false),
(Request.CreateValues("bob", "data2", "write"), false)
};

e.TestBatchEnforceWithMatcher(matcher, testCases);
}

[Fact]
public void TestBatchEnforceWithMatcherParallel()
{
Enforcer e = new(_testModelFixture.GetBasicTestModel());
string matcher = "r.sub != p.sub && r.obj == p.obj && r.act == p.act";

IEnumerable<(RequestValues<string, string, string>, bool)> testCases =
new (RequestValues<string, string, string>, bool)[]
{
(Request.CreateValues("alice", "data1", "read"), false),
(Request.CreateValues("alice", "data1", "write"), false),
(Request.CreateValues("alice", "data2", "read"), false),
(Request.CreateValues("alice", "data2", "write"), true),
(Request.CreateValues("bob", "data1", "read"), true),
(Request.CreateValues("bob", "data1", "write"), false),
(Request.CreateValues("bob", "data2", "read"), false),
(Request.CreateValues("bob", "data2", "write"), false)
};

e.TestBatchEnforceWithMatcherParallel(matcher, testCases);
}

[Fact]
public async Task TestEnforceWithMatcherAsync()
{
Expand All @@ -1046,6 +1192,28 @@ public async Task TestEnforceWithMatcherAsync()
await e.TestEnforceWithMatcherAsync(matcher, "bob", "data2", "write", false);
}

[Fact]
public void TestBatchEnforceWithMatcherApiAsync()
{
Enforcer e = new(_testModelFixture.GetBasicTestModel());
string matcher = "r.sub != p.sub && r.obj == p.obj && r.act == p.act";

IEnumerable<(RequestValues<string, string, string>, bool)> testCases =
new (RequestValues<string, string, string>, bool)[]
{
(Request.CreateValues("alice", "data1", "read"), false),
(Request.CreateValues("alice", "data1", "write"), false),
(Request.CreateValues("alice", "data2", "read"), false),
(Request.CreateValues("alice", "data2", "write"), true),
(Request.CreateValues("bob", "data1", "read"), true),
(Request.CreateValues("bob", "data1", "write"), false),
(Request.CreateValues("bob", "data2", "read"), false),
(Request.CreateValues("bob", "data2", "write"), false)
};

TestBatchEnforceWithMatcherAsync(e, matcher, testCases);
}

[Fact]
public void TestEnforceExWithMatcherApi()
{
Expand Down
56 changes: 56 additions & 0 deletions Casbin.UnitTests/Util/TestUtil.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Linq;
using System.Threading.Tasks;
using Casbin.Rbac;
using Casbin.Model;
using Casbin.Util;
using Xunit;

Expand All @@ -25,6 +26,32 @@ internal static void TestEnforce<T1, T2, T3>(IEnforcer e, T1 sub, T2 obj, T3 act
internal static async Task TestEnforceAsync<T1, T2, T3>(IEnforcer e, T1 sub, T2 obj, T3 act, bool res) =>
Assert.Equal(res, await e.EnforceAsync(sub, obj, act));

internal static void TestBatchEnforce<T>(IEnforcer e, IEnumerable<(T, bool)> values) where T : IRequestValues =>
Assert.True(values.Select(x => x.Item2).SequenceEqual(e.BatchEnforce(values.Select(x => x.Item1))));

internal static void TestBatchEnforceParallel<T>(Enforcer e, IEnumerable<(T, bool)> values) where T : IRequestValues =>
Assert.True(values.Select(x => x.Item2).SequenceEqual(e.BatchEnforceParallel(values.Select(x => x.Item1).ToList())));

internal static async void TestBatchEnforceAsync<T>(IEnforcer e, IEnumerable<(T, bool)> values) where T : IRequestValues
{
#if !NET452
var res = e.BatchEnforceAsync(values.Select(x => x.Item1));
#else
var res = await e.BatchEnforceAsync(values.Select(x => x.Item1));
#endif
var expectedResults = values.Select(x => x.Item2);
var expectedResultEnumerator = expectedResults.GetEnumerator();
#if !NET452
await foreach(var item in res)
#else
foreach(var item in res)
#endif
{
expectedResultEnumerator.MoveNext();
Assert.Equal(expectedResultEnumerator.Current, item);
}
}

internal static void TestDomainEnforce<T1, T2, T3, T4>(IEnforcer e, T1 sub, T2 dom, T3 obj, T4 act, bool res) =>
Assert.Equal(res, e.Enforce(sub, dom, obj, act));

Expand All @@ -33,6 +60,35 @@ internal static void TestEnforceWithMatcher<T1, T2, T3>(this IEnforcer e, string

internal static async Task TestEnforceWithMatcherAsync<T1, T2, T3>(this IEnforcer e, string matcher, T1 sub, T2 obj,
T3 act, bool res) => Assert.Equal(res, await e.EnforceWithMatcherAsync(matcher, sub, obj, act));

internal static void TestBatchEnforceWithMatcher<T>(this IEnforcer e, string matcher, IEnumerable<(T, bool)> values)
where T : IRequestValues =>
Assert.True(values.Select(x => x.Item2).SequenceEqual(e.BatchEnforceWithMatcher(matcher, values.Select(x => x.Item1))));

internal static void TestBatchEnforceWithMatcherParallel<T>(this Enforcer e, string matcher, IEnumerable<(T, bool)> values)
where T : IRequestValues =>
Assert.True(values.Select(x => x.Item2).SequenceEqual(e.BatchEnforceWithMatcherParallel<T>(matcher, values.Select(x => x.Item1).ToList())));

internal static async void TestBatchEnforceWithMatcherAsync<T>(IEnforcer e, string matcher, IEnumerable<(T, bool)> values)
where T : IRequestValues
{
#if !NET452
var res = e.BatchEnforceWithMatcherAsync(matcher, values.Select(x => x.Item1));
#else
var res = await e.BatchEnforceWithMatcherAsync(matcher, values.Select(x => x.Item1));
#endif
var expectedResults = values.Select(x => x.Item2);
var expectedResultEnumerator = expectedResults.GetEnumerator();
#if !NET452
await foreach(var item in res)
#else
foreach(var item in res)
#endif
{
expectedResultEnumerator.MoveNext();
Assert.Equal(expectedResultEnumerator.Current, item);
}
}

internal static void TestEnforceEx<T1, T2, T3>(IEnforcer e, T1 sub, T2 obj, T3 act, List<string> res)
{
Expand Down
31 changes: 30 additions & 1 deletion Casbin/Abstractions/IEnforcer.cs
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
using System.Threading.Tasks;
using System.Collections.Generic;
using Casbin.Caching;
using Casbin.Effect;
using Casbin.Evaluation;
using Casbin.Model;
using Casbin.Persist;
using Casbin.Rbac;
#if !NET452
using Microsoft.Extensions.Logging;
#endif

namespace Casbin
{
#if !NET452
using BatchEnforceAsyncResults = IAsyncEnumerable<bool>;
#else
using BatchEnforceAsyncResults = Task<IEnumerable<bool>>;
#endif

/// <summary>
/// IEnforcer is the API interface of Enforcer
/// </summary>
Expand Down Expand Up @@ -60,5 +67,27 @@ public interface IEnforcer
/// can be class instances if ABAC is used.</param>
/// <returns>Whether to allow the request.</returns>
public Task<bool> EnforceAsync<TRequest>(EnforceContext context, TRequest requestValues) where TRequest : IRequestValues;

/// <summary>
/// Decides whether some "subject" can access corresponding "object" with the operation
/// "action", input parameters are usually: (sub, obj, act).
/// </summary>
/// <param name="context">Enforce context include all status on enforcing</param>
/// <param name="requestValues">The requests needs to be mediated, whose element is usually an array of strings
/// but can be class instances if ABAC is used.</param>
/// <returns>Whether to allow the requests.</returns>
public IEnumerable<bool> BatchEnforce<TRequest>(EnforceContext context, IEnumerable<TRequest> requestValues)
where TRequest : IRequestValues;

/// <summary>
/// Decides whether some "subject" can access corresponding "object" with the operation
/// "action", input parameters are usually: (sub, obj, act).
/// </summary>
/// <param name="context">Enforce context include all status on enforcing</param>
/// <param name="requestValues">The requests needs to be mediated, whose element is usually an array of strings
/// but can be class instances if ABAC is used.</param>
/// <returns>Whether to allow the requests.</returns>
public BatchEnforceAsyncResults BatchEnforceAsync<TRequest>(EnforceContext context, IEnumerable<TRequest> requestValues)
where TRequest : IRequestValues;
}
}
68 changes: 68 additions & 0 deletions Casbin/Enforcer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,74 @@ public async Task<bool> EnforceAsync<TRequest>(EnforceContext context, TRequest
return result;
}

/// <summary>
/// Decides whether some "subject" can access corresponding "object" with the operation
/// "action", input parameters are usually: (sub, obj, act).
/// </summary>
/// <param name="context">Enforce context include all status on enforcing</param>
/// <param name="requestValues">The requests needs to be mediated, whose element is usually an array of strings
/// but can be class instances if ABAC is used.</param>
/// <returns>Whether to allow the requests.</returns>
public IEnumerable<bool> BatchEnforce<TRequest>(EnforceContext context,
IEnumerable<TRequest> requestValues) where TRequest : IRequestValues
{
foreach(var requestValue in requestValues){
yield return this.Enforce(context, requestValue);
}
}

/// <summary>
/// Decides whether some "subject" can access corresponding "object" with the operation
/// "action", input parameters are usually: (sub, obj, act). The method uses multi-thread
/// to accelerate the process.
/// </summary>
/// <param name="context">Enforce context include all status on enforcing</param>
/// <param name="requestValues">The requests needs to be mediated, whose element is usually an array of strings
/// but can be class instances if ABAC is used.</param>
/// <param name="maxDegreeOfParallelism">The max degree of parallelism of the process.</param>
/// <returns>Whether to allow the requests.</returns>
public IEnumerable<bool> ParallelBatchEnforce<TRequest>(EnforceContext context,
IReadOnlyList<TRequest> requestValues, int maxDegreeOfParallelism = -1) where TRequest : IRequestValues
{
int n = requestValues.Count;
if(n == 0) return new bool[] { };
bool[] res = new bool[n];
Parallel.For(0, n, new ParallelOptions() { MaxDegreeOfParallelism = maxDegreeOfParallelism }, Index =>
{
res[Index] = this.Enforce(context, requestValues[Index]);
});
return res;
}

/// <summary>
/// Decides whether some "subject" can access corresponding "object" with the operation
/// "action", input parameters are usually: (sub, obj, act).
/// </summary>
/// <param name="context">Enforce context include all status on enforcing</param>
/// <param name="requestValues">The requests needs to be mediated, whose element is usually an array of strings
/// but can be class instances if ABAC is used.</param>
/// <returns>Whether to allow the requests.</returns>
#if !NET452
public async IAsyncEnumerable<bool> BatchEnforceAsync<TRequest>(EnforceContext context, IEnumerable<TRequest> requestValues)
where TRequest : IRequestValues
{
foreach(var requestValue in requestValues){
bool res = await this.EnforceAsync(context, requestValue);
yield return res;
}
}
#else
public async Task<IEnumerable<bool>> BatchEnforceAsync<TRequest>(EnforceContext context, IEnumerable<TRequest> requestValues)
where TRequest : IRequestValues
{
List<bool> res = new List<bool>();
foreach(var requestValue in requestValues){
res.Add(await this.EnforceAsync(context, requestValue));
}
return res;
}
#endif

private Task<bool> InternalEnforceAsync<TRequest>(EnforceContext context, IPolicyManager policyManager,
TRequest requestValues) where TRequest : IRequestValues
{
Expand Down
Loading

0 comments on commit c107068

Please sign in to comment.