Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 40 additions & 10 deletions casbin/async_internal_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,10 @@ async def save_policy(self):
else:
update_for_save_policy(self.model)
else:
self.watcher.update()
if inspect.iscoroutinefunction(self.watcher.update):
await self.watcher.update()
else:
self.watcher.update()

async def _add_policy(self, sec, ptype, rule):
"""async adds a rule to the current policy."""
Expand All @@ -133,7 +136,10 @@ async def _add_policy(self, sec, ptype, rule):
else:
update_for_add_policy(sec, ptype, rule)
else:
self.watcher.update()
if inspect.iscoroutinefunction(self.watcher.update):
await self.watcher.update()
else:
self.watcher.update()

rule_added = self.model.add_policy(sec, ptype, rule)

Expand Down Expand Up @@ -161,7 +167,10 @@ async def _add_policies(self, sec, ptype, rules):
else:
update_for_add_policies(sec, ptype, rules)
else:
self.watcher.update()
if inspect.iscoroutinefunction(self.watcher.update):
await self.watcher.update()
else:
self.watcher.update()

rules_added = self.model.add_policies(sec, ptype, rules)

Expand All @@ -180,7 +189,10 @@ async def _update_policy(self, sec, ptype, old_rule, new_rule):
return False

if self.watcher and self.auto_notify_watcher:
self.watcher.update()
if inspect.iscoroutinefunction(self.watcher.update):
await self.watcher.update()
else:
self.watcher.update()

return rule_updated

Expand All @@ -197,7 +209,10 @@ async def _update_policies(self, sec, ptype, old_rules, new_rules):
return False

if self.watcher and self.auto_notify_watcher:
self.watcher.update()
if inspect.iscoroutinefunction(self.watcher.update):
await self.watcher.update()
else:
self.watcher.update()

return rules_updated

Expand Down Expand Up @@ -225,7 +240,10 @@ async def _update_filtered_policies(self, sec, ptype, new_rules, field_index, *f
if sec == "g":
self.build_role_links()
if self.watcher and self.auto_notify_watcher:
self.watcher.update()
if inspect.iscoroutinefunction(self.watcher.update):
await self.watcher.update()
else:
self.watcher.update()
return is_rule_changed

async def _remove_policy(self, sec, ptype, rule):
Expand All @@ -247,7 +265,10 @@ async def _remove_policy(self, sec, ptype, rule):
else:
update_for_remove_policy(sec, ptype, rule)
else:
self.watcher.update()
if inspect.iscoroutinefunction(self.watcher.update):
await self.watcher.update()
else:
self.watcher.update()

return rule_removed

Expand All @@ -273,7 +294,10 @@ async def _remove_policies(self, sec, ptype, rules):
else:
update_for_remove_policies(sec, ptype, rules)
else:
self.watcher.update()
if inspect.iscoroutinefunction(self.watcher.update):
await self.watcher.update()
else:
self.watcher.update()

return rules_removed

Expand All @@ -296,7 +320,10 @@ async def _remove_filtered_policy(self, sec, ptype, field_index, *field_values):
else:
update_for_remove_filtered_policy(sec, ptype, field_index, *field_values)
else:
self.watcher.update()
if inspect.iscoroutinefunction(self.watcher.update):
await self.watcher.update()
else:
self.watcher.update()

return rule_removed

Expand All @@ -312,7 +339,10 @@ async def _remove_filtered_policy_returns_effects(self, sec, ptype, field_index,
return False

if self.watcher and self.auto_notify_watcher:
self.watcher.update()
if inspect.iscoroutinefunction(self.watcher.update):
await self.watcher.update()
else:
self.watcher.update()

return rule_removed

Expand Down
58 changes: 58 additions & 0 deletions tests/test_watcher_ex.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,24 @@ def test_auto_notify_disabled(self):
self.assertEqual(w.notify_message, None)


class AsyncMinimalWatcher:
"""A minimal async watcher that only implements async update() method."""

def __init__(self):
self.update_count = 0

async def update(self):
"""update the policy"""
self.update_count += 1
return True

async def close(self):
pass

async def set_update_callback(self, callback):
pass


class TestAsyncWatcherEx(IsolatedAsyncioTestCase):
def get_enforcer(self, model=None, adapter=None):
return casbin.AsyncEnforcer(
Expand Down Expand Up @@ -365,3 +383,43 @@ async def test_auto_notify_disabled(self):

await e.remove_policies(rules)
self.assertEqual(w.notify_message, None)

async def test_async_minimal_watcher(self):
"""Test that a watcher with only async update() method works properly."""
e = self.get_enforcer(
get_examples("basic_model.conf"),
get_examples("basic_policy.csv"),
)
await e.load_policy()

w = AsyncMinimalWatcher()
e.set_watcher(w)
e.enable_auto_notify_watcher(True)

# Test save_policy
await e.save_policy()
self.assertEqual(w.update_count, 1)

# Test add_policy - fallback to update()
await e.add_policy("admin", "data1", "read")
self.assertEqual(w.update_count, 2)

# Test remove_policy - fallback to update()
await e.remove_policy("admin", "data1", "read")
self.assertEqual(w.update_count, 3)

# Test remove_filtered_policy - fallback to update()
await e.remove_filtered_policy(1, "data1")
self.assertEqual(w.update_count, 4)

# Test add_policies - fallback to update()
rules = [
["jack", "data4", "read"],
["katy", "data4", "write"],
]
await e.add_policies(rules)
self.assertEqual(w.update_count, 5)

# Test remove_policies - fallback to update()
await e.remove_policies(rules)
self.assertEqual(w.update_count, 6)