Skip to content

Commit

Permalink
Merge pull request RasaHQ#9990 from RasaHQ/tayfun/9922-fix-rasa-init-…
Browse files Browse the repository at this point in the history
…cache-directory-with-chdir

Fix training cache bug in `rasa init`
  • Loading branch information
Tayfun Sen authored Oct 29, 2021
2 parents b85594a + a83fbfe commit 691d857
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 9 deletions.
21 changes: 12 additions & 9 deletions rasa/cli/scaffold.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def add_subparser(
scaffold_parser.set_defaults(func=run)


def print_train_or_instructions(args: argparse.Namespace, path: Text) -> None:
def print_train_or_instructions(args: argparse.Namespace) -> None:
"""Train a model if the user wants to."""
import questionary
import rasa
Expand All @@ -61,12 +61,12 @@ def print_train_or_instructions(args: argparse.Namespace, path: Text) -> None:

if should_train:
print_success("Training an initial model...")
config = os.path.join(path, DEFAULT_CONFIG_PATH)
training_files = os.path.join(path, DEFAULT_DATA_PATH)
domain = os.path.join(path, DEFAULT_DOMAIN_PATH)
output = os.path.join(path, create_output_path())

training_result = rasa.train(domain, config, training_files, output)
training_result = rasa.train(
DEFAULT_DOMAIN_PATH,
DEFAULT_CONFIG_PATH,
DEFAULT_DATA_PATH,
create_output_path(),
)
args.model = training_result.model

print_run_or_instructions(args)
Expand Down Expand Up @@ -125,12 +125,15 @@ def print_run_or_instructions(args: argparse.Namespace) -> None:


def init_project(args: argparse.Namespace, path: Text) -> None:
create_initial_project(path)
"""Inits project."""
os.chdir(path)
create_initial_project(".")
print("Created project directory at '{}'.".format(os.path.abspath(path)))
print_train_or_instructions(args, path)
print_train_or_instructions(args)


def create_initial_project(path: Text) -> None:
"""Creates directory structure and templates for initial project."""
from distutils.dir_util import copy_tree

copy_tree(scaffold_path(), path)
Expand Down
49 changes: 49 additions & 0 deletions tests/cli/test_rasa_init.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import argparse
import os
from pathlib import Path
from typing import Callable
from _pytest.pytester import RunResult
from _pytest.monkeypatch import MonkeyPatch

from rasa.cli import scaffold
from tests.conftest import enable_cache
from tests.core.channels.test_cmdline import mock_stdin


def test_init_using_init_dir_option(run_with_stdin: Callable[..., RunResult]):
Expand Down Expand Up @@ -50,3 +56,46 @@ def test_init_help(run: Callable[..., RunResult]):
def test_user_asked_to_train_model(run_with_stdin: Callable[..., RunResult]):
run_with_stdin("init", stdin=b"\nYN")
assert not os.path.exists("models")


def test_train_data_in_project_dir(monkeypatch: MonkeyPatch, tmp_path: Path):
"""Test cache directory placement.
Tests cache directories for training data are in project root, not
where `rasa init` is run.
"""
# We would like to test CLI but can't run it with popen because we want
# to be able to monkeypatch it. Solution is to call functions inside CLI
# module. Initial project folder should have been created before
# `init_project`, that's what we do here.
monkeypatch.chdir(tmp_path)
new_project_folder_path = tmp_path / "new-project-folder"
new_project_folder_path.mkdir()

# Simulate CLI run arguments.
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
scaffold.add_subparser(subparsers, parents=[])

args = parser.parse_args(["init", "--no-prompt",])

# Simple config which should train fast.
def mock_get_config(*args):
return {
"language": "en",
"pipeline": [{"name": "KeywordIntentClassifier"}],
"policies": [{"name": "RulePolicy"}],
"recipe": "default.v1",
}

monkeypatch.setattr(
"rasa.shared.importers.importer.CombinedDataImporter.get_config",
mock_get_config,
)
# Cache dir is auto patched to be a temp directory, this makes it
# go back to local project folder so we can test it is created correctly.
with enable_cache(Path(".rasa", "cache")):
mock_stdin([])
scaffold.init_project(args, str(new_project_folder_path))
assert os.getcwd() == str(new_project_folder_path)
assert os.path.exists(".rasa/cache")

0 comments on commit 691d857

Please sign in to comment.