Skip to content

Commit

Permalink
introduced undo to online learning
Browse files Browse the repository at this point in the history
  • Loading branch information
tmbo committed Sep 11, 2018
1 parent 949463c commit ad95492
Show file tree
Hide file tree
Showing 21 changed files with 197 additions and 171 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,4 @@ examples/moodbot/*.png
examples/moodbot/errors.json
docs/key
docs/key.pub
failed_stories.md
3 changes: 1 addition & 2 deletions examples/concertbot/train_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import logging

from rasa_core import utils, train
from rasa_core import utils, train, run
from rasa_core.training import online

logger = logging.getLogger(__name__)
Expand All @@ -15,7 +15,6 @@ def train_agent():
return train.train_dialogue_model(domain_file="domain.yml",
stories_file="data/stories.md",
output_path="models/dialogue",
endpoints="endpoints.yml",
max_history=2,
kwargs={"batch_size": 50,
"epochs": 200,
Expand Down
5 changes: 2 additions & 3 deletions rasa_core/actions/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,9 @@ def __init__(self, name, action_endpoint):
def _action_call_format(self, tracker, domain):
# type: (DialogueStateTracker, Domain) -> Dict[Text, Any]
"""Create the request json send to the action server."""
from rasa_core.trackers import EventVerbosity

tracker_state = tracker.current_state(
should_include_events=True,
should_ignore_restarts=True)
tracker_state = tracker.current_state(EventVerbosity.ALL)

return {
"next_action": self._name,
Expand Down
4 changes: 2 additions & 2 deletions rasa_core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from rasa_core.policies.memoization import MemoizationPolicy
from rasa_core.processor import MessageProcessor
from rasa_core.tracker_store import InMemoryTrackerStore, TrackerStore
from rasa_core.trackers import DialogueStateTracker
from rasa_core.trackers import DialogueStateTracker, EventVerbosity
from rasa_core.utils import EndpointConfig
from rasa_nlu.utils import is_url

Expand Down Expand Up @@ -328,7 +328,7 @@ def log_message(

processor = self._create_processor(message_preprocessor)
tracker = processor.log_message(message)
return tracker.current_state(should_include_events=True)
return tracker.current_state(EventVerbosity.AFTER_RESTART)

def execute_action(
self,
Expand Down
5 changes: 2 additions & 3 deletions rasa_core/nlg/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from rasa_core.constants import DEFAULT_REQUEST_TIMEOUT
from rasa_core.nlg.generator import NaturalLanguageGenerator
from rasa_core.trackers import DialogueStateTracker
from rasa_core.trackers import DialogueStateTracker, EventVerbosity
from rasa_core.utils import EndpointConfig

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -75,8 +75,7 @@ def nlg_request_format(template_name, tracker, output_channel, **kwargs):
# type: (Text, DialogueStateTracker, Text, Any) -> Dict[Text, Any]
"""Create the json body for the NLG json body for the request."""

tracker_state = tracker.current_state(should_include_events=True,
should_ignore_restarts=True)
tracker_state = tracker.current_state(EventVerbosity.ALL)

return {
"template": template_name,
Expand Down
4 changes: 2 additions & 2 deletions rasa_core/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from rasa_core.nlg import NaturalLanguageGenerator
from rasa_core.policies.ensemble import PolicyEnsemble
from rasa_core.tracker_store import TrackerStore
from rasa_core.trackers import DialogueStateTracker
from rasa_core.trackers import DialogueStateTracker, EventVerbosity
from rasa_core.utils import EndpointConfig

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -107,7 +107,7 @@ def predict_next(self, sender_id):
return {
"scores": scores,
"policy": policy,
"tracker": tracker.current_state(should_include_events=True)
"tracker": tracker.current_state(EventVerbosity.AFTER_RESTART)
}

def log_message(self, message):
Expand Down
15 changes: 6 additions & 9 deletions rasa_core/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from rasa_core import utils
from rasa_core.domain import Domain
from rasa_core.events import Event
from rasa_core.trackers import DialogueStateTracker
from rasa_core.trackers import DialogueStateTracker, EventVerbosity
from rasa_core.utils import EndpointConfig

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -48,30 +48,27 @@ def clients(self):
def tracker(self,
sender_id, # type: Text
domain, # type: Domain
should_ignore_restarts=False, # type: bool
include_events=True, # type: bool
event_verbosity=EventVerbosity.ALL, # type: EventVerbosity
until=None # type: Optional[int]
):
"""Retrieve and recreate a tracker fetched from the remote instance."""

tracker_json = self.tracker_json(
sender_id, should_ignore_restarts,
include_events, until)
sender_id, event_verbosity, until)

tracker = DialogueStateTracker.from_dict(
sender_id, tracker_json.get("events", []), domain.slots)
return tracker

def tracker_json(self,
sender_id, # type: Text
should_ignore_restarts=True, # type: bool
include_events=True, # type: bool
event_verbosity=EventVerbosity.ALL, # type: EventVerbosity
until=None # type: Optional[int]
):
"""Retrieve a tracker's json representation from remote instance."""

url = "/conversations/{}/tracker?ignore_restarts={}&events={}".format(
sender_id, should_ignore_restarts, include_events)
url = "/conversations/{}/tracker?events={}".format(
sender_id, event_verbosity.name)
if until:
url += "&until={}".format(until)

Expand Down
3 changes: 2 additions & 1 deletion rasa_core/restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from rasa_core.interpreter import NaturalLanguageInterpreter
from rasa_core.run import load_agent
from rasa_core.trackers import DialogueStateTracker
from rasa_core.utils import AvailableEndpoints

logger = logging.getLogger() # get the root logger

Expand Down Expand Up @@ -122,7 +123,7 @@ def serve_application(model_directory, # type: Text
):
from rasa_core import run

_endpoints = run.read_endpoints(endpoints)
_endpoints = AvailableEndpoints.read_endpoints(endpoints)

nlu = NaturalLanguageInterpreter.create(nlu_model, _endpoints.nlu)

Expand Down
22 changes: 2 additions & 20 deletions rasa_core/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,10 @@
BUILTIN_CHANNELS)
from rasa_core.interpreter import (
NaturalLanguageInterpreter)
from rasa_core.utils import read_yaml_file
from rasa_core.utils import read_yaml_file, AvailableEndpoints

logger = logging.getLogger() # get the root logger

AvailableEndpoints = namedtuple('AvailableEndpoints', 'nlg '
'nlu '
'action '
'model')


def create_argument_parser():
"""Parse all the command line arguments for the run script."""
Expand Down Expand Up @@ -90,19 +85,6 @@ def create_argument_parser():
return parser


def read_endpoints(endpoint_file):
nlg = utils.read_endpoint_config(endpoint_file,
endpoint_type="nlg")
nlu = utils.read_endpoint_config(endpoint_file,
endpoint_type="nlu")
action = utils.read_endpoint_config(endpoint_file,
endpoint_type="action_endpoint")
model = utils.read_endpoint_config(endpoint_file,
endpoint_type="models")

return AvailableEndpoints(nlg, nlu, action, model)


def _create_external_channels(channel, credentials_file):
# type: (Optional[Text], Optional[Text]) -> List[InputChannel]

Expand Down Expand Up @@ -248,7 +230,7 @@ def load_agent(core_model, interpreter, endpoints,

logger.info("Rasa process starting")

_endpoints = read_endpoints(cmdline_args.endpoints)
_endpoints = AvailableEndpoints.read_endpoints(cmdline_args.endpoints)
_interpreter = NaturalLanguageInterpreter.create(cmdline_args.nlu,
_endpoints.nlu)
_agent = load_agent(cmdline_args.core,
Expand Down
34 changes: 22 additions & 12 deletions rasa_core/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from rasa_core.events import Event
from rasa_core.interpreter import NaturalLanguageInterpreter
from rasa_core.policies import PolicyEnsemble
from rasa_core.trackers import DialogueStateTracker
from rasa_core.trackers import DialogueStateTracker, EventVerbosity
from rasa_core.utils import AvailableEndpoints
from rasa_core.version import __version__
from rasa_core.channels import UserMessage

Expand Down Expand Up @@ -129,7 +130,7 @@ def execute_action(sender_id):

# retrieve tracker and set to requested state
tracker = agent.tracker_store.get_or_create_tracker(sender_id)
state = tracker.current_state(should_include_events=True)
state = tracker.current_state(EventVerbosity.AFTER_RESTART)
return jsonify({"tracker": state,
"messages": out.messages})

Expand Down Expand Up @@ -162,7 +163,7 @@ def append_event(sender_id):
logger.warning(
"Append event called, but could not extract a "
"valid event. Request JSON: {}".format(request_params))
return jsonify(tracker.current_state(should_include_events=True))
return jsonify(tracker.current_state(EventVerbosity.AFTER_RESTART))

@app.route("/conversations/<sender_id>/tracker/events",
methods=['PUT'])
Expand All @@ -178,7 +179,7 @@ def replace_events(sender_id):
agent.domain.slots)
# will override an existing tracker with the same id!
agent.tracker_store.save(tracker)
return jsonify(tracker.current_state(should_include_events=True))
return jsonify(tracker.current_state(EventVerbosity.AFTER_RESTART))

@app.route("/conversations",
methods=['GET', 'OPTIONS'])
Expand All @@ -202,10 +203,20 @@ def retrieve_tracker(sender_id):
status=503)

# parameters
should_ignore_restarts = utils.bool_arg('ignore_restarts',
default=False)
should_include_events = utils.bool_arg('events',
default=True)
if request.args.get('ignore_restarts') is not None:
return Response("Parameter 'ignore_restarts' is not supported "
"anymore. use `events` instead.",
status=404)

event_verbosity_str = request.args.get('events', default="ALL").upper()
try:
verbosity = EventVerbosity[event_verbosity_str]
except KeyError:
enum_values = ", ".join([e.name for e in EventVerbosity])
return Response("Invalid parameter value for 'events'. Should be "
"one of {}".format(enum_values),
status=404)

until_time = request.args.get('until', None)

# retrieve tracker and set to requested state
Expand All @@ -219,9 +230,8 @@ def retrieve_tracker(sender_id):
tracker = tracker.travel_back_in_time(float(until_time))

# dump and return tracker
state = tracker.current_state(
should_include_events=should_include_events,
should_ignore_restarts=should_ignore_restarts)

state = tracker.current_state(verbosity)
return jsonify(state)

@app.route("/conversations/<sender_id>/respond",
Expand Down Expand Up @@ -452,7 +462,7 @@ def tracker_predict():

logger.info("Rasa process starting")

_endpoints = run.read_endpoints(cmdline_args.endpoints)
_endpoints = AvailableEndpoints.read_endpoints(cmdline_args.endpoints)
_interpreter = NaturalLanguageInterpreter.create(cmdline_args.nlu,
_endpoints.nlu)
_agent = run.load_agent(cmdline_args.core,
Expand Down
7 changes: 4 additions & 3 deletions rasa_core/tracker_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

from rasa_core.actions.action import ACTION_LISTEN_NAME
from rasa_core.broker import EventChannel
from rasa_core.trackers import DialogueStateTracker, ActionExecuted
from rasa_core.trackers import (
DialogueStateTracker, ActionExecuted,
EventVerbosity)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -176,8 +178,7 @@ def save(self, tracker, timeout=None):
if self.event_broker:
self.stream_events(tracker)

state = tracker.current_state(should_include_events=True,
should_ignore_restarts=True)
state = tracker.current_state(EventVerbosity.ALL)

self.conversations.update_one(
{"sender_id": tracker.sender_id},
Expand Down
38 changes: 28 additions & 10 deletions rasa_core/trackers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from collections import deque

import typing
from enum import Enum
from typing import Generator, Dict, Text, Any, Optional, Iterator
from typing import List

Expand All @@ -28,6 +29,25 @@
from rasa_core.domain import Domain


class EventVerbosity(Enum):
"""Filter on which events to include in tracker dumps."""

# no events will be included
NONE = 1

# all events, that contribute to the trackers state are included
# these are all you need to reconstruct the tracker state
APPLIED = 2

# include even more events, in this case everything that comes
# after the most recent restart event. this will also include
# utterances that got reverted and actions that got undone.
AFTER_RESTART = 3

# include every logged event
ALL = 4


class DialogueStateTracker(object):
"""Maintains the state of a conversation."""

Expand Down Expand Up @@ -85,18 +105,16 @@ def __init__(self, sender_id, slots,
###
# Public tracker interface
###
def current_state(self,
should_include_events=False,
should_ignore_restarts=False):
# type: (bool, bool) -> Dict[Text, Any]
def current_state(self, event_verbosity=EventVerbosity.NONE):
# type: (EventVerbosity) -> Dict[Text, Any]
"""Return the current tracker state as an object."""

if should_include_events:
if should_ignore_restarts:
es = self.events
else:
es = self.events_after_latest_restart()
evts = [e.as_dict() for e in es]
if event_verbosity == EventVerbosity.ALL:
evts = [e.as_dict() for e in self.events]
elif event_verbosity == EventVerbosity.AFTER_RESTART:
evts = [e.as_dict() for e in self.events_after_latest_restart()]
elif event_verbosity == EventVerbosity.APPLIED:
evts = [e.as_dict() for e in self.applied_events()]
else:
evts = None

Expand Down
7 changes: 4 additions & 3 deletions rasa_core/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import argparse

from rasa_core import utils, run
from rasa_core import utils
from rasa_core.agent import Agent
from rasa_core.constants import (
DEFAULT_NLU_FALLBACK_THRESHOLD,
Expand All @@ -18,6 +18,7 @@
from rasa_core.policies import FallbackPolicy
from rasa_core.policies.keras_policy import KerasPolicy
from rasa_core.policies.memoization import MemoizationPolicy
from rasa_core.run import AvailableEndpoints
from rasa_core.training import online


Expand Down Expand Up @@ -129,7 +130,7 @@ def create_argument_parser():

def train_dialogue_model(domain_file, stories_file, output_path,
interpreter=None,
endpoints=None,
endpoints=AvailableEndpoints(),
max_history=None,
dump_flattened_stories=False,
kwargs=None):
Expand Down Expand Up @@ -199,7 +200,7 @@ def train_dialogue_model(domain_file, stories_file, output_path,
else:
stories = cmdline_args.stories

_endpoints = run.read_endpoints(cmdline_args.endpoints)
_endpoints = AvailableEndpoints.read_endpoints(cmdline_args.endpoints)
_interpreter = NaturalLanguageInterpreter.create(cmdline_args.nlu,
_endpoints.nlu)

Expand Down
Loading

0 comments on commit ad95492

Please sign in to comment.