Skip to content

Commit

Permalink
feat/transformers
Browse files Browse the repository at this point in the history
closes #82
  • Loading branch information
JarbasAl committed Apr 20, 2024
1 parent 0eff030 commit 8c61722
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 19 deletions.
File renamed without changes.
64 changes: 46 additions & 18 deletions hivemind_core/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,6 @@
from enum import Enum, IntEnum
from typing import List, Dict, Optional

from ovos_bus_client import MessageBusClient
from ovos_bus_client.message import Message
from ovos_bus_client.session import Session
from ovos_utils.log import LOG
from poorman_handshake import HandShake, PasswordHandShake
from tornado import ioloop
from tornado.websocket import WebSocketHandler

from hivemind_bus_client.message import HiveMessage, HiveMessageType
from hivemind_bus_client.serialization import decode_bitstring, get_bitstring
from hivemind_bus_client.util import (
Expand All @@ -19,6 +11,17 @@
decrypt_from_json,
encrypt_as_json,
)
from ovos_bus_client import MessageBusClient
from ovos_bus_client.message import Message
from ovos_bus_client.session import Session
from ovos_bus_client.util import get_message_lang
from ovos_config import Configuration
from ovos_utils.log import LOG
from poorman_handshake import HandShake, PasswordHandShake
from tornado import ioloop
from tornado.websocket import WebSocketHandler

from hivemind_core.transformers import MetadataTransformersService, UtteranceTransformersService


class ProtocolVersion(IntEnum):
Expand Down Expand Up @@ -253,11 +256,18 @@ class HiveMindListenerProtocol:
mycroft_bus_callback = None # slave asked to inject payload into mycroft bus
shared_bus_callback = None # passive sharing of slave device bus (info)

utterance_plugins: UtteranceTransformersService = None
metadata_plugins: MetadataTransformersService = None

def bind(self, websocket, bus):
websocket.protocol = self
self.internal_protocol = HiveMindListenerInternalProtocol(bus)
self.internal_protocol.register_bus_handlers()

config = Configuration().get("hivemind", {})
self.utterance_plugins = UtteranceTransformersService(bus, config=config)
self.metadata_plugins = MetadataTransformersService(bus, config=config)

def get_bus(self, client: HiveMindClientConnection):
# allow subclasses to use dedicated bus per client
return self.internal_protocol.bus
Expand Down Expand Up @@ -303,9 +313,9 @@ def handle_new_client(self, client: HiveMindClientConnection):
"max_protocol_version": max_version,
"binarize": True, # report we support the binarization scheme
"preshared_key": client.crypto_key
is not None, # do we have a pre-shared key (V0 proto)
is not None, # do we have a pre-shared key (V0 proto)
"password": client.pswd_handshake
is not None, # is password available (V1 proto, replaces pre-shared key)
is not None, # is password available (V1 proto, replaces pre-shared key)
"crypto_required": self.require_crypto, # do we allow unencrypted payloads
}
msg = HiveMessage(HiveMessageType.HANDSHAKE, payload)
Expand Down Expand Up @@ -381,7 +391,7 @@ def handle_message(self, message: HiveMessage, client: HiveMindClientConnection)

# HiveMind protocol messages - from slave -> master
def handle_unknown_message(
self, message: HiveMessage, client: HiveMindClientConnection
self, message: HiveMessage, client: HiveMindClientConnection
):
"""message handler for non default message types, subclasses can
handle their own types here
Expand All @@ -390,13 +400,13 @@ def handle_unknown_message(
"""

def handle_binary_message(
self, message: HiveMessage, client: HiveMindClientConnection
self, message: HiveMessage, client: HiveMindClientConnection
):
assert message.msg_type == HiveMessageType.BINARY
# TODO

def handle_handshake_message(
self, message: HiveMessage, client: HiveMindClientConnection
self, message: HiveMessage, client: HiveMindClientConnection
):
LOG.debug("handshake received, generating session key")
payload = message.payload
Expand Down Expand Up @@ -450,15 +460,33 @@ def handle_handshake_message(
msg = HiveMessage(HiveMessageType.HANDSHAKE, payload)
client.send(msg) # client can recreate crypto_key on his side now

def _handle_transformers(self, message: Message) -> Message:
"""
Pipe utterance through transformer plugins to get more metadata.
Utterances may be modified by any parser and context overwritten
"""
lang = get_message_lang(message) # per query lang or default Configuration lang
original = utterances = message.data.get('utterances', [])
message.context["lang"] = lang
utterances, message.context = self.utterance_plugins.transform(utterances, message.context)
if original != utterances:
message.data["utterances"] = utterances
LOG.debug(f"utterances transformed: {original} -> {utterances}")
message.context = self.metadata_plugins.transform(message.context)
return message

def handle_bus_message(
self, message: HiveMessage, client: HiveMindClientConnection
self, message: HiveMessage, client: HiveMindClientConnection
):
if message.payload.msg_type == "recognizer_loop:utterance":
message._payload = self._handle_transformers(message.payload).serialize()

self.handle_inject_mycroft_msg(message.payload, client)
if self.mycroft_bus_callback:
self.mycroft_bus_callback(message.payload)

def handle_broadcast_message(
self, message: HiveMessage, client: HiveMindClientConnection
self, message: HiveMessage, client: HiveMindClientConnection
):
"""
message (HiveMessage): HiveMind message object
Expand Down Expand Up @@ -492,7 +520,7 @@ def _unpack_message(self, message: HiveMessage, client: HiveMindClientConnection
return pload

def handle_propagate_message(
self, message: HiveMessage, client: HiveMindClientConnection
self, message: HiveMessage, client: HiveMindClientConnection
):
"""
message (HiveMessage): HiveMind message object
Expand Down Expand Up @@ -533,7 +561,7 @@ def handle_propagate_message(
bus.emit(message)

def handle_escalate_message(
self, message: HiveMessage, client: HiveMindClientConnection
self, message: HiveMessage, client: HiveMindClientConnection
):
"""
message (HiveMessage): HiveMind message object
Expand Down Expand Up @@ -578,7 +606,7 @@ def update_slave_session(self, message: Message, client: HiveMindClientConnectio
return message

def handle_inject_mycroft_msg(
self, message: Message, client: HiveMindClientConnection
self, message: Message, client: HiveMindClientConnection
):
"""
message (Message): mycroft bus message object
Expand Down
122 changes: 122 additions & 0 deletions hivemind_core/transformers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from typing import Optional, List

from ovos_plugin_manager.metadata_transformers import find_metadata_transformer_plugins
from ovos_plugin_manager.text_transformers import find_utterance_transformer_plugins

from ovos_utils.json_helper import merge_dict
from ovos_utils.log import LOG


class UtteranceTransformersService:

def __init__(self, bus, config=None):
self.config_core = config or {}
self.loaded_plugins = {}
self.has_loaded = False
self.bus = bus
self.config = self.config_core.get("utterance_transformers") or {}
self.load_plugins()

def load_plugins(self):
for plug_name, plug in find_utterance_transformer_plugins().items():
if plug_name in self.config:
# if disabled skip it
if not self.config[plug_name].get("active", True):
continue
try:
self.loaded_plugins[plug_name] = plug()
LOG.info(f"loaded utterance transformer plugin: {plug_name}")
except Exception as e:
LOG.error(e)
LOG.exception(f"Failed to load utterance transformer plugin: {plug_name}")

@property
def plugins(self):
"""
Return loaded transformers in priority order, such that modules with a
higher `priority` rank are called first and changes from lower ranked
transformers are applied last
A plugin of `priority` 1 will override any existing context keys and
will be the last to modify utterances`
"""
return sorted(self.loaded_plugins.values(),
key=lambda k: k.priority, reverse=True)

def shutdown(self):
for module in self.plugins:
try:
module.shutdown()
except:
pass

def transform(self, utterances: List[str], context: Optional[dict] = None):
context = context or {}

for module in self.plugins:
try:
utterances, data = module.transform(utterances, context)
_safe = {k:v for k,v in data.items() if k != "session"} # no leaking TTS/STT creds in logs
LOG.debug(f"{module.name}: {_safe}")
context = merge_dict(context, data)
except Exception as e:
LOG.warning(f"{module.name} transform exception: {e}")
return utterances, context


class MetadataTransformersService:

def __init__(self, bus, config=None):
self.config_core = config or {}
self.loaded_plugins = {}
self.has_loaded = False
self.bus = bus
self.config = self.config_core.get("metadata_transformers") or {}
self.load_plugins()

def load_plugins(self):
for plug_name, plug in find_metadata_transformer_plugins().items():
if plug_name in self.config:
# if disabled skip it
if not self.config[plug_name].get("active", True):
continue
try:
self.loaded_plugins[plug_name] = plug()
LOG.info(f"loaded metadata transformer plugin: {plug_name}")
except Exception as e:
LOG.error(e)
LOG.exception(f"Failed to load metadata transformer plugin: {plug_name}")

@property
def plugins(self):
"""
Return loaded transformers in priority order, such that modules with a
higher `priority` rank are called first and changes from lower ranked
transformers are applied last.
A plugin of `priority` 1 will override any existing context keys
"""
return sorted(self.loaded_plugins.values(),
key=lambda k: k.priority, reverse=True)

def shutdown(self):
for module in self.plugins:
try:
module.shutdown()
except:
pass

def transform(self, context: Optional[dict] = None):
context = context or {}

for module in self.plugins:
try:
data = module.transform(context)
_safe = {k:v for k,v in data.items() if k != "session"} # no leaking TTS/STT creds in logs
LOG.debug(f"{module.name}: {_safe}")
context = merge_dict(context, data)
except Exception as e:
LOG.warning(f"{module.name} transform exception: {e}")
return context


3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ tornado
ovos_utils>=0.0.33
pycryptodomex
HiveMind_presence>=0.0.2a3
ovos-bus-client>=0.0.6a5
ovos-bus-client>=0.0.6
ovos-plugin-manager
poorman_handshake>=0.1.0
click
click_default_group
Expand Down

0 comments on commit 8c61722

Please sign in to comment.