Skip to content

Commit

Permalink
feat/skill_intent_blacklist_per_client (#89)
Browse files Browse the repository at this point in the history
* feat/skill_intent_blacklist_per_client

* add script

  blacklist-intent  blacklist intents from being triggered by a client

  blacklist-skill   blacklist skills from being triggered by a client

* unblacklist commands

* dont require restart
  • Loading branch information
JarbasAl authored Jul 5, 2024
1 parent 204bd39 commit 16eee11
Show file tree
Hide file tree
Showing 3 changed files with 258 additions and 16 deletions.
42 changes: 35 additions & 7 deletions hivemind_core/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from dataclasses import dataclass, field
from enum import Enum, IntEnum
from typing import List, Dict, Optional
import pgpy

import pgpy
from ovos_bus_client import MessageBusClient
from ovos_bus_client.message import Message
from ovos_bus_client.session import Session
Expand All @@ -13,15 +13,16 @@
from tornado import ioloop
from tornado.websocket import WebSocketHandler

from hivemind_bus_client.identity import NodeIdentity
from hivemind_bus_client.message import HiveMessage, HiveMessageType
from hivemind_bus_client.serialization import decode_bitstring, get_bitstring
from hivemind_bus_client.identity import NodeIdentity
from hivemind_bus_client.util import (
decrypt_bin,
encrypt_bin,
decrypt_from_json,
encrypt_as_json,
)
from hivemind_core.database import ClientDatabase


class ProtocolVersion(IntEnum):
Expand Down Expand Up @@ -63,9 +64,15 @@ class HiveMindClientConnection:
pswd_handshake: Optional[PasswordHandShake] = None
socket: Optional[WebSocketHandler] = None
crypto_key: Optional[str] = None
blacklist: List[str] = field(
msg_blacklist: List[str] = field(
default_factory=list
) # list of ovos message_type to never be sent to this client
skill_blacklist: List[str] = field(
default_factory=list
) # list of skill_id that can't match for this client
intent_blacklist: List[str] = field(
default_factory=list
) # list of skill_id:intent_name that can't match for this client
allowed_types: List[str] = field(
default_factory=list
) # list of ovos message_type to allow to be sent from this client
Expand All @@ -88,7 +95,7 @@ def send(self, message: HiveMessage):
else:
_msg_type = message.payload.msg_type

if _msg_type in self.blacklist:
if _msg_type in self.msg_blacklist:
return LOG.debug(
f"message type {_msg_type} " f"is blacklisted for {self.peer}"
)
Expand Down Expand Up @@ -673,8 +680,30 @@ def handle_intercom_message(
return False

# HiveMind mycroft bus messages - from slave -> master
def _update_blacklist(self, message: Message, client: HiveMindClientConnection):
LOG.debug("replacing message metadata with hivemind client session")
message.context["session"] = client.sess.serialize()

# update blacklist from db, to account for changes without requiring a restart
with ClientDatabase() as users:
user = users.get_client_by_api_key(client.key)
client.skill_blacklist = user.blacklist.get("skills", [])
client.intent_blacklist = user.blacklist.get("intents", [])

# inject client specific blacklist into session
if "blacklisted_skills" not in message.context["session"]:
message.context["session"]["blacklisted_skills"] = []
if "blacklisted_intents" not in message.context["session"]:
message.context["session"]["blacklisted_intents"] = []

message.context["session"]["blacklisted_skills"] += [s for s in client.skill_blacklist
if s not in message.context["session"]["blacklisted_skills"]]
message.context["session"]["blacklisted_intents"] += [s for s in client.intent_blacklist
if s not in message.context["session"]["blacklisted_intents"]]
return message

def handle_inject_mycroft_msg(
self, message: Message, client: HiveMindClientConnection
self, message: Message, client: HiveMindClientConnection
):
"""
message (Message): mycroft bus message object
Expand All @@ -688,8 +717,7 @@ def handle_inject_mycroft_msg(
return

# ensure client specific session data is injected in query to ovos
LOG.debug("replacing message metadata with hivemind client session")
message.context["session"] = client.sess.serialize()
message = self._update_blacklist(message, client)
if message.msg_type == "speak":
message.context["destination"] = ["audio"] # make audible, this is injected "speak" command
elif message.context.get("destination") is None:
Expand Down
226 changes: 219 additions & 7 deletions hivemind_core/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,13 +187,13 @@ def list_clients():
default="hivemind",
)
def listen(
ovos_bus_address: str,
ovos_bus_port: int,
host: str,
port: int,
ssl: bool,
cert_dir: str,
cert_name: str,
ovos_bus_address: str,
ovos_bus_port: int,
host: str,
port: int,
ssl: bool,
cert_dir: str,
cert_name: str,
):
from hivemind_core.service import HiveMindService

Expand All @@ -216,5 +216,217 @@ def listen(
service.run()


@hmcore_cmds.command(help="blacklist skills from being triggered by a client", name="blacklist-skill")
@click.argument("skill_id", required=True, type=str)
@click.argument("node_id", required=False, type=int)
def blacklist_skill(skill_id, node_id):
if not node_id:
# list clients and prompt for id using rich
table = Table(title="HiveMind Clients")
table.add_column("ID", justify="right", style="cyan", no_wrap=True)
table.add_column("Name", style="magenta")
table.add_column("Allowed Msg Types", style="yellow")
_choices = []
for client in ClientDatabase():
if client["client_id"] != -1:
table.add_row(
str(client["client_id"]),
client["name"],
str(client.get("allowed_types", [])),
)
_choices.append(str(client["client_id"]))

if not _choices:
print("No clients found!")
exit()
elif len(_choices) > 1:
console = Console()
console.print(table)
_exit = str(max(int(i) for i in _choices) + 1)
node_id = Prompt.ask(
f"To which client you want to blacklist '{skill_id}'? ({_exit}='Exit')",
choices=_choices + [_exit],
)
if node_id == _exit:
console.print("User exit", style="red")
exit()
else:
node_id = _choices[0]

with ClientDatabase() as db:
for client in db:
if client["client_id"] == int(node_id):
blacklist = client.get("blacklist", {"messages": [], "skills": [], "intents": []})
if skill_id in blacklist["skills"]:
print(f"Client {client['name']} already blacklisted '{skill_id}'")
exit()

blacklist["skills"].append(skill_id)
client["blacklist"] = blacklist
item_id = db.get_item_id(client)
db.update_item(item_id, client)
print(f"Blacklisted '{skill_id}' for {client['name']}")
break


@hmcore_cmds.command(help="remove skills from a client blacklist", name="unblacklist-skill")
@click.argument("skill_id", required=True, type=str)
@click.argument("node_id", required=False, type=int)
def unblacklist_skill(skill_id, node_id):
if not node_id:
# list clients and prompt for id using rich
table = Table(title="HiveMind Clients")
table.add_column("ID", justify="right", style="cyan", no_wrap=True)
table.add_column("Name", style="magenta")
table.add_column("Allowed Msg Types", style="yellow")
_choices = []
for client in ClientDatabase():
if client["client_id"] != -1:
table.add_row(
str(client["client_id"]),
client["name"],
str(client.get("allowed_types", [])),
)
_choices.append(str(client["client_id"]))

if not _choices:
print("No clients found!")
exit()
elif len(_choices) > 1:
console = Console()
console.print(table)
_exit = str(max(int(i) for i in _choices) + 1)
node_id = Prompt.ask(
f"To which client you want to blacklist '{skill_id}'? ({_exit}='Exit')",
choices=_choices + [_exit],
)
if node_id == _exit:
console.print("User exit", style="red")
exit()
else:
node_id = _choices[0]

with ClientDatabase() as db:
for client in db:
if client["client_id"] == int(node_id):
blacklist = client.get("blacklist", {"messages": [], "skills": [], "intents": []})
if skill_id not in blacklist["skills"]:
print(f"'{skill_id}' is not blacklisted for client {client['name']}")
exit()

blacklist["skills"].pop(skill_id)
client["blacklist"] = blacklist
item_id = db.get_item_id(client)
db.update_item(item_id, client)
print(f"Blacklisted '{skill_id}' for {client['name']}")
break


@hmcore_cmds.command(help="blacklist intents from being triggered by a client", name="blacklist-intent")
@click.argument("intent_id", required=True, type=str)
@click.argument("node_id", required=False, type=int)
def blacklist_intent(intent_id, node_id):
if not node_id:
# list clients and prompt for id using rich
table = Table(title="HiveMind Clients")
table.add_column("ID", justify="right", style="cyan", no_wrap=True)
table.add_column("Name", style="magenta")
table.add_column("Allowed Msg Types", style="yellow")
_choices = []
for client in ClientDatabase():
if client["client_id"] != -1:
table.add_row(
str(client["client_id"]),
client["name"],
str(client.get("allowed_types", [])),
)
_choices.append(str(client["client_id"]))

if not _choices:
print("No clients found!")
exit()
elif len(_choices) > 1:
console = Console()
console.print(table)
_exit = str(max(int(i) for i in _choices) + 1)
node_id = Prompt.ask(
f"To which client you want to blacklist '{intent_id}'? ({_exit}='Exit')",
choices=_choices + [_exit],
)
if node_id == _exit:
console.print("User exit", style="red")
exit()
else:
node_id = _choices[0]

with ClientDatabase() as db:
for client in db:
if client["client_id"] == int(node_id):
blacklist = client.get("blacklist", {"messages": [], "skills": [], "intents": []})
if intent_id in blacklist["intents"]:
print(f"Client {client['name']} already blacklisted '{intent_id}'")
exit()

blacklist["intents"].append(intent_id)
client["blacklist"] = blacklist
item_id = db.get_item_id(client)
db.update_item(item_id, client)
print(f"Blacklisted '{intent_id}' for {client['name']}")
break


@hmcore_cmds.command(help="remove intents from a client blacklist", name="unblacklist-intent")
@click.argument("intent_id", required=True, type=str)
@click.argument("node_id", required=False, type=int)
def unblacklist_intent(intent_id, node_id):
if not node_id:
# list clients and prompt for id using rich
table = Table(title="HiveMind Clients")
table.add_column("ID", justify="right", style="cyan", no_wrap=True)
table.add_column("Name", style="magenta")
table.add_column("Allowed Msg Types", style="yellow")
_choices = []
for client in ClientDatabase():
if client["client_id"] != -1:
table.add_row(
str(client["client_id"]),
client["name"],
str(client.get("allowed_types", [])),
)
_choices.append(str(client["client_id"]))

if not _choices:
print("No clients found!")
exit()
elif len(_choices) > 1:
console = Console()
console.print(table)
_exit = str(max(int(i) for i in _choices) + 1)
node_id = Prompt.ask(
f"To which client you want to blacklist '{intent_id}'? ({_exit}='Exit')",
choices=_choices + [_exit],
)
if node_id == _exit:
console.print("User exit", style="red")
exit()
else:
node_id = _choices[0]

with ClientDatabase() as db:
for client in db:
if client["client_id"] == int(node_id):
blacklist = client.get("blacklist", {"messages": [], "skills": [], "intents": []})
if intent_id not in blacklist["intents"]:
print(f" '{intent_id}' not blacklisted for Client {client['name']} ")
exit()

blacklist["intents"].pop(intent_id)
client["blacklist"] = blacklist
item_id = db.get_item_id(client)
db.update_item(item_id, client)
print(f"Blacklisted '{intent_id}' for {client['name']}")
break


if __name__ == "__main__":
hmcore_cmds()
6 changes: 4 additions & 2 deletions hivemind_core/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def open(self):
name=name,
ip=self.request.remote_ip,
socket=self,
sess=Session(session_id="default"), # will be re-assigned once client sends it's own
sess=Session(session_id="default"), # will be re-assigned once client sends handshake
handshake=handshake,
loop=self.protocol.loop,
)
Expand All @@ -144,7 +144,9 @@ def open(self):
return

self.client.crypto_key = user.crypto_key
self.client.blacklist = user.blacklist.get("messages", [])
self.client.msg_blacklist = user.blacklist.get("messages", [])
self.client.skill_blacklist = user.blacklist.get("skills", [])
self.client.intent_blacklist = user.blacklist.get("intents", [])
self.client.allowed_types = user.allowed_types
self.client.can_broadcast = user.can_broadcast
self.client.can_propagate = user.can_propagate
Expand Down

0 comments on commit 16eee11

Please sign in to comment.