Skip to content

Commit

Permalink
Use prompt_toolkit for user input
Browse files Browse the repository at this point in the history
Prompt's from rich are limited (e.g. arrow keys do not work, etc.)
  • Loading branch information
s4n-cz committed Nov 15, 2024
1 parent 7e01f9f commit 30ab5d3
Show file tree
Hide file tree
Showing 11 changed files with 92 additions and 65 deletions.
8 changes: 6 additions & 2 deletions sereto/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import click
import keyring
from prompt_toolkit import prompt
from pydantic import FilePath, validate_call

from sereto.cli.commands import sereto_ls, sereto_repl
Expand Down Expand Up @@ -72,7 +73,9 @@ def new(settings: Settings, report_id: TypeProjectId) -> None:
settings: The settings object containing the tool's global configuration.
report_id: The ID of the report to be created.
"""
new_report(settings=settings, report_id=report_id)
Console().print("[cyan]We will ask you a few questions to set up the new report.\n")
name = prompt("Name of the report: ")
new_report(settings=settings, id=report_id, name=name)


@cli.command()
Expand Down Expand Up @@ -372,7 +375,7 @@ def config_targets_delete(project: Project, index: int) -> None:
project: Report's project representation.
index: The index of the target to be deleted. You can obtain the index by running `sereto config targets show`.
"""
delete_targets_config(project=project, index=index)
delete_targets_config(project=project, index=index, interactive=True)


@config_targets.command(name="show")
Expand Down Expand Up @@ -450,6 +453,7 @@ def finding_add(project: Project, target: str | None, format: str, name: str) ->
target_selector=target,
format=format,
name=name,
interactive=True,
)


Expand Down
3 changes: 1 addition & 2 deletions sereto/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,7 @@ def repl_cd(settings: Settings, project_id: TypeProjectId | Literal["-"]) -> Non
Args:
settings: The Settings object.
cmd: The user input command.
wd: The WorkingDir object.
project_id: The ID of the project to switch to. Use '-' to go back to the previous working directory.
Raises:
SeretoValueError: If the report ID is invalid.
Expand Down
26 changes: 14 additions & 12 deletions sereto/cli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import shutil

import click
from prompt_toolkit.shortcuts import yes_no_dialog
from pydantic import TypeAdapter, validate_call
from rich import box
from rich.prompt import Confirm
from rich.table import Table

from sereto.cli.date import prompt_user_for_date
Expand Down Expand Up @@ -111,7 +111,7 @@ def add_dates_config(project: Project) -> None:
dates: list[Date] = cfg.dates if len(cfg.updates) == 0 else cfg.updates[-1].dates

# Add a new date
date_type: DateType = load_enum(enum=DateType, prompt="Type")
date_type: DateType = load_enum(enum=DateType, message="Type:")
new_date = prompt_user_for_date(date_type=date_type)
dates.append(Date(type=date_type, date=new_date))

Expand Down Expand Up @@ -213,7 +213,7 @@ def add_people_config(project: Project) -> None:
people: list[Person] = cfg.people if len(cfg.updates) == 0 else cfg.updates[-1].people

# Add a new person
person_type: PersonType = load_enum(enum=PersonType, prompt="Type")
person_type: PersonType = load_enum(enum=PersonType, message="Type:")
new_person = prompt_user_for_person(person_type=person_type)
people.append(new_person)

Expand Down Expand Up @@ -325,33 +325,35 @@ def add_targets_config(project: Project) -> None:


@validate_call
def delete_targets_config(project: Project, index: int) -> None:
def delete_targets_config(project: Project, index: int, interactive: bool = False) -> None:
"""Delete target from the configuration by its index.
Args:
project: Report's project representation.
index: Index to item which should be deleted. First item is 1.
interactive: Whether to ask for confirmations.
"""
cfg = project.config
targets: list[Target] = cfg.targets if len(cfg.updates) == 0 else cfg.updates[-1].targets

# Validate the index, convert to 0-based
index -= 1
if not 0 <= index <= len(targets) - 1:
raise SeretoValueError("invalid index, not in allowed range")
del targets[index]

# Write the configuration
# Extract the filesystem path before deleting the values
target_path = targets[index].path

# Delete target from the config
del targets[index]
cfg.dump_json(file=project.get_config_path())

# Delete the target from the filesystem
target_path = targets[index].path
if (
target_path is not None
and target_path.is_dir()
and Confirm.ask(
f'[yellow]Delete "{target_path}" from the filesystem?',
console=Console(),
default=False,
)
and interactive
and yes_no_dialog(title="Confirm", text=f"Delete '{target_path}' from the filesystem?").run()
):
shutil.rmtree(target_path)

Expand Down
18 changes: 9 additions & 9 deletions sereto/cli/date.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
from prompt_toolkit import prompt
from pydantic import validate_call
from rich.prompt import Prompt

from sereto.cli.utils import Console
from sereto.models.date import TYPES_WITH_ALLOWED_RANGE, DateRange, DateType, SeretoDate


@validate_call
def _prompt_date(prompt: str, default: SeretoDate | None = None) -> SeretoDate | None:
def _prompt_date(message: str, default: SeretoDate | None = None) -> SeretoDate | None:
"""Interactively prompt the user for a date in the format DD-Mmm-YYYY.
Args:
message: The message to display to the user.
default: The default date, which the user can easily accept. Defaults to None.
Returns:
The data as provided by the user, or None if the input was invalid.
SeretoDate if correct input was provided, None otherwise
"""
if default is not None:
user_input: str = Prompt.ask(prompt, console=Console(), default=str(default))
else:
user_input = Prompt.ask(prompt, console=Console())
user_input = prompt(message) if default is None else prompt(message, default=str(default))

try:
return SeretoDate.from_str(user_input)
Expand All @@ -42,13 +42,13 @@ def prompt_user_for_date(date_type: DateType) -> SeretoDate | DateRange:
while True:
# Prompt user for the start date
prompt: str = f"Date{' start' if allow_range else ''} (DD-Mmm-YYYY)"
if (start_date := _prompt_date(prompt)) is None:
if (start_date := _prompt_date(f"{prompt}: ")) is None:
Console().print("[red]Invalid input, try again\n")
continue

# Prompt user for the end date, if the date type allows it
if allow_range:
if (end_date := _prompt_date("Date end (DD-Mmm-YYYY)", default=start_date)) is None:
if (end_date := _prompt_date("Date end (DD-Mmm-YYYY): ", default=start_date)) is None:
Console().print("[red]Invalid input, try again\n")
continue
else:
Expand Down
20 changes: 13 additions & 7 deletions sereto/cli/person.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from prompt_toolkit import prompt
from pydantic import EmailStr, TypeAdapter, ValidationError, validate_call
from rich.prompt import Prompt

from sereto.cli.utils import Console
from sereto.models.person import Person, PersonType
Expand All @@ -15,16 +15,22 @@ def prompt_user_for_person(person_type: PersonType) -> Person:
Returns:
The person as provided by the user.
"""
name: str | None = Prompt.ask("Name", console=Console(), default=None)
business_unit: str | None = Prompt.ask("Business unit", console=Console(), default=None)
name = prompt("Name: ")
business_unit = prompt("Business unit: ")
while True:
try:
e: str | None = Prompt.ask("Email", console=Console(), default=None)
e = prompt("Email: ")
ta: TypeAdapter[EmailStr] = TypeAdapter(EmailStr) # hack for mypy
email: EmailStr | None = ta.validate_python(e) if e is not None else None
email: EmailStr | None = ta.validate_python(e) if len(e) > 0 else None
break
except ValidationError:
Console().print("[red]Please enter valid email address")
role: str | None = Prompt.ask("Role", console=Console(), default=None)
role = prompt("Role: ")

return Person(type=person_type, name=name, business_unit=business_unit, email=email, role=role)
return Person(
type=person_type,
name=name if len(name) > 0 else None,
business_unit=business_unit if len(business_unit) > 0 else None,
email=email,
role=role if len(role) > 0 else None,
)
11 changes: 8 additions & 3 deletions sereto/cli/target.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import click
from rich.prompt import Prompt
from prompt_toolkit import prompt
from prompt_toolkit.shortcuts import radiolist_dialog

from sereto.cli.utils import Console
from sereto.exceptions import SeretoRuntimeError
Expand All @@ -17,8 +18,12 @@ def prompt_user_for_target(settings: Settings) -> Target:
The target as provided by the user.
"""
Console().line()
category = Prompt.ask("Category", choices=list(settings.categories), console=Console())
name = Prompt.ask("Name", console=Console())
category = radiolist_dialog(
title="New target",
text="Category:",
values=[(c, c.upper()) for c in list(settings.categories)],
).run()
name = prompt("Name: ")

match category:
case "dast":
Expand Down
11 changes: 8 additions & 3 deletions sereto/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from typing import TypeVar

import click
from prompt_toolkit.shortcuts import radiolist_dialog
from rich.console import Console as RichConsole
from rich.prompt import Prompt

from sereto.cli.aliases import cli_aliases
from sereto.singleton import Singleton
Expand Down Expand Up @@ -55,9 +55,14 @@ def resolve_command(
EnumType = TypeVar("EnumType", bound=Enum)


def load_enum(enum: type[EnumType], prompt: str) -> EnumType:
def load_enum(enum: type[EnumType], message: str) -> EnumType:
"""Let user select a value from enum."""
choice = Prompt.ask(prompt=prompt, choices=[e.value for e in enum])
choice = radiolist_dialog(
title="Select value",
text=message,
values=[(e.name, e.value) for e in enum],
).run()

return enum(choice)


Expand Down
11 changes: 6 additions & 5 deletions sereto/finding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import frontmatter # type: ignore[import-untyped]
from prompt_toolkit.shortcuts import yes_no_dialog
from pydantic import ValidationError, validate_call
from rich.prompt import Confirm
from rich.table import Table
from ruamel.yaml.comments import CommentedMap, CommentedSeq

Expand All @@ -22,6 +22,7 @@ def add_finding(
target_selector: str | None,
format: str,
name: str,
interactive: bool = False,
) -> None:
target = project.select_target(selector=target_selector)

Expand All @@ -40,10 +41,10 @@ def add_finding(
finding_dir.mkdir(exist_ok=True)
dst_path = finding_dir / f"{name}{project.config.last_version().path_suffix}.{format}.j2"

if dst_path.is_file() and not Confirm.ask(
f'[yellow]Destination "{dst_path}" exists. Overwrite?',
console=Console(),
default=False,
# Destination file exists and we cannot proceed
if dst_path.is_file() and (
not interactive
or not yes_no_dialog(title="Warning", text=f"Destination '{dst_path}' exists. Overwrite?").run()
):
raise SeretoRuntimeError("cannot proceed")

Expand Down
17 changes: 7 additions & 10 deletions sereto/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from shutil import copy2, copytree

from pydantic import DirectoryPath, validate_call
from rich.prompt import Prompt

from sereto.cli.utils import Console
from sereto.exceptions import SeretoPathError
Expand Down Expand Up @@ -53,32 +52,30 @@ def copy_skel(templates: DirectoryPath, dst: DirectoryPath, overwrite: bool = Fa


@validate_call
def new_report(settings: Settings, report_id: TypeProjectId) -> None:
def new_report(settings: Settings, id: TypeProjectId, name: str) -> None:
"""Generates a new report with the specified ID.
Args:
settings: Global settings.
report_id: The ID of the new report. This should be a string that uniquely identifies the report.
id: The ID of the new report. This should be a string that uniquely identifies the report.
name: The name of the new report.
Raises:
SeretoValueError: If a report with the specified ID already exists in the `reports` directory.
"""
Console().log(f"Generating a new report with ID {report_id!r}")
Console().log(f"Generating a new report with ID {id!r}")

if (new_path := (settings.reports_path / report_id)).exists():
if (new_path := (settings.reports_path / id)).exists():
raise SeretoPathError("report with specified ID already exists")
else:
new_path.mkdir()

Console().print("[cyan]We will ask you a few questions to set up the new report.\n")

report_name: str = Prompt.ask("Name of the report", console=Console())
sereto_ver = importlib.metadata.version("sereto")

cfg = Config(
sereto_version=SeretoVersion.from_str(sereto_ver),
id=report_id,
name=report_name,
id=id,
name=name,
report_version=ReportVersion.from_str("v1.0"),
)

Expand Down
17 changes: 7 additions & 10 deletions sereto/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from typing import TypeVar

from click import get_current_context
from prompt_toolkit import prompt
from prompt_toolkit.shortcuts import yes_no_dialog
from pydantic import validate_call
from rich.prompt import Confirm, Prompt
from typing_extensions import ParamSpec

from sereto.cli.utils import Console
Expand All @@ -26,9 +27,9 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
return wrapper


def _ask_for_dirpath(prompt: str) -> Path:
def _ask_for_dirpath(message: str) -> Path:
while True:
input = Prompt.ask(prompt, console=Console())
input = prompt(f"{message}: ")
path = Path(input).resolve()

if path.exists():
Expand All @@ -38,11 +39,7 @@ def _ask_for_dirpath(prompt: str) -> Path:
Console().print("the provided path is not a directory")
continue
else:
if Confirm.ask(
f'[yellow]Directory "{path}" does not exist. Create?',
console=Console(),
default=True,
):
if yes_no_dialog(title="Warning", text=f"Directory '{path}' does not exist. Create?").run():
path.mkdir(parents=True)
return path

Expand All @@ -53,8 +50,8 @@ def load_settings_function() -> Settings:
else:
Console().print("[cyan]It seems like this is the first time you're running the tool. Let's set it up!\n")

reports_path = _ask_for_dirpath(":open_file_folder: Enter the path to the reports directory")
templates_path = _ask_for_dirpath(":open_file_folder: Enter the path to the templates directory")
reports_path = _ask_for_dirpath("Enter the path to the reports directory")
templates_path = _ask_for_dirpath("Enter the path to the templates directory")

Console().print("\nThank you! The minimal setup is complete.")
Console().print(
Expand Down
Loading

0 comments on commit 30ab5d3

Please sign in to comment.