Skip to content

feat(r): Added prompt_path argument to R package #37

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions pkg-py/src/querychat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from querychat.querychat import init, mod_server as server, sidebar, system_prompt, mod_ui as ui
from querychat.querychat import init, sidebar, system_prompt
from querychat.querychat import mod_server as server
from querychat.querychat import mod_ui as ui

__all__ = ["init", "server", "sidebar", "ui", "system_prompt"]
__all__ = ["init", "server", "sidebar", "system_prompt", "ui"]
34 changes: 28 additions & 6 deletions pkg-py/src/querychat/querychat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@
import chatlas
import chevron
import narwhals as nw
import pandas as pd
import sqlalchemy
from narwhals.typing import IntoFrame
from shiny import Inputs, Outputs, Session, module, reactive, ui

if TYPE_CHECKING:
Expand Down Expand Up @@ -145,6 +143,7 @@ def system_prompt(
data_description: Optional[str] = None,
extra_instructions: Optional[str] = None,
categorical_threshold: int = 10,
prompt_path: Optional[Path] = None,
) -> str:
"""
Create a system prompt for the chat model based on a data source's schema
Expand All @@ -162,6 +161,9 @@ def system_prompt(
categorical_threshold : int, default=10
Threshold for determining if a column is categorical based on number of
unique values
prompt_path : Path, optional
Optional `Path` to a custom prompt file. If not provided, the default
querychat template will be used.

Returns
-------
Expand All @@ -170,7 +172,11 @@ def system_prompt(

"""
# Read the prompt file
prompt_path = Path(__file__).parent / "prompt" / "prompt.md"
if prompt_path is None:
# Default to the bundled prompt file in the `prompt/` directory next to this module
# This allows for easy customization by placing a different prompt.md file there
prompt_path = Path(__file__).parent / "prompt" / "prompt.md"

prompt_text = prompt_path.read_text()

return chevron.render(
Expand Down Expand Up @@ -226,11 +232,14 @@ def df_to_html(df: IntoFrame, maxrows: int = 5) -> str:
def init(
data_source: IntoFrame | sqlalchemy.Engine,
table_name: str,
/,
*,
greeting: Optional[str] = None,
data_description: Optional[str] = None,
extra_instructions: Optional[str] = None,
create_chat_callback: Optional[CreateChatCallback] = None,
prompt_path: Optional[Path] = None,
system_prompt_override: Optional[str] = None,
create_chat_callback: Optional[CreateChatCallback] = None,
) -> QueryChatConfig:
"""
Initialize querychat with any compliant data source.
Expand All @@ -251,10 +260,22 @@ def init(
Description of the data in plain text or Markdown
extra_instructions : str, optional
Additional instructions for the chat model
prompt_path : Path, optional
Path to a custom prompt file. If not provided, the default querychat
template will be used. This should be a Markdown file that contains the
system prompt template. The mustache template can use the following
variables:
- `{{db_engine}}`: The database engine used (e.g., "DuckDB")
- `{{schema}}`: The schema of the data source, generated by
`data_source.get_schema()`
- `{{data_description}}`: The optional data description provided
- `{{extra_instructions}}`: Any additional instructions provided
system_prompt_override : str, optional
A custom system prompt to use instead of the default. If provided,
`data_description`, `extra_instructions`, and `prompt_path` will be
silently ignored.
create_chat_callback : CreateChatCallback, optional
A function that creates a chat object
system_prompt_override : str, optional
A custom system prompt to use instead of the default

Returns
-------
Expand Down Expand Up @@ -289,6 +310,7 @@ def init(
data_source_obj,
data_description,
extra_instructions,
prompt_path=prompt_path,
)

# Default chat function if none provided
Expand Down
1 change: 1 addition & 0 deletions pkg-r/NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Generated by roxygen2: do not edit by hand

export(df_to_schema)
export(querychat_init)
export(querychat_server)
export(querychat_sidebar)
Expand Down
5 changes: 5 additions & 0 deletions pkg-r/NEWS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# querychat (development version)

* Initial CRAN submission.

* Added `prompt_path` support for `querychat_system_prompt()`. (Thank you, @oacar! #37)
68 changes: 51 additions & 17 deletions pkg-r/R/prompt.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,32 @@
#' schema and optional additional context and instructions.
#'
#' @param df A data frame to generate schema information from.
#' @param name A string containing the name of the table in SQL queries.
#' @param data_description Optional description of the data, in plain text or Markdown format.
#' @param extra_instructions Optional additional instructions for the chat model, in plain text or Markdown format.
#' @param table_name A string containing the name of the table in SQL queries.
#' @param data_description Optional string in plain text or Markdown format, containing
#' a description of the data frame or any additional context that might be
#' helpful in understanding the data. This will be included in the system
#' prompt for the chat model.
#' @param extra_instructions Optional string in plain text or Markdown format, containing
#' any additional instructions for the chat model. These will be appended at
#' the end of the system prompt.
#' @param categorical_threshold The maximum number of unique values for a text column to be considered categorical.
#' @param prompt_path Optional string containing the path to a custom prompt file. If
#' `NULL`, the default prompt file in the package will be used. This file should
#' contain a whisker template for the system prompt, with placeholders for `{{schema}}`,
#' `{{data_description}}`, and `{{extra_instructions}}`.
#'
#' @return A string containing the system prompt for the chat model.
#'
#' @export
querychat_system_prompt <- function(
df,
name,
table_name,
data_description = NULL,
extra_instructions = NULL,
categorical_threshold = 10
categorical_threshold = 10,
prompt_path = system.file("prompt", "prompt.md", package = "querychat")
) {
schema <- df_to_schema(df, name, categorical_threshold)
schema <- df_to_schema(df, table_name, categorical_threshold)

if (!is.null(data_description)) {
data_description <- paste(data_description, collapse = "\n")
Expand All @@ -29,26 +39,50 @@ querychat_system_prompt <- function(
}

# Read the prompt file
prompt_path <- system.file("prompt", "prompt.md", package = "querychat")
if (is.null(prompt_path)) {
prompt_path <- system.file("prompt", "prompt.md", package = "querychat")
}
if (!file.exists(prompt_path)) {
stop("Prompt file not found at: ", prompt_path)
}
prompt_content <- readLines(prompt_path, warn = FALSE)
prompt_text <- paste(prompt_content, collapse = "\n")

whisker::whisker.render(
prompt_text,
list(
schema = schema,
data_description = data_description,
extra_instructions = extra_instructions
processed_template <-
whisker::whisker.render(
prompt_text,
list(
schema = schema,
data_description = data_description,
extra_instructions = extra_instructions
)
)
)

attr(processed_template, "table_name") <- table_name

processed_template
}

#' Generate a schema description from a data frame
#'
#' This function generates a schema description for a data frame, including
#' the column names, their types, and additional information such as ranges for
#' numeric columns and unique values for text columns.
#'
#' @param df A data frame to generate schema information from.
#' @param table_name A string containing the name of the table in SQL queries.
#' @param categorical_threshold The maximum number of unique values for a text column to be considered categorical.
#'
#' @return A string containing the schema description for the data frame.
#' The schema includes the table name, column names, their types, and additional
#' information such as ranges for numeric columns and unique values for text columns.
#' @export
df_to_schema <- function(
df,
name = deparse(substitute(df)),
categorical_threshold
table_name = deparse(substitute(df)),
categorical_threshold = 10
) {
schema <- c(paste("Table:", name), "Columns:")
schema <- c(paste("Table:", table_name), "Columns:")

column_info <- lapply(names(df), function(column) {
# Map R classes to SQL-like types
Expand Down
63 changes: 33 additions & 30 deletions pkg-r/R/querychat.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,23 @@
#' Shiny sessions in the R process.
#'
#' @param df A data frame.
#' @param tbl_name A string containing a valid table name for the data frame,
#' @param table_name A string containing a valid table name for the data frame,
#' that will appear in SQL queries. Ensure that it begins with a letter, and
#' contains only letters, numbers, and underscores. By default, querychat will
#' try to infer a table name using the name of the `df` argument.
#' @param greeting A string in Markdown format, containing the initial message
#' to display to the user upon first loading the chatbot. If not provided, the
#' LLM will be invoked at the start of the conversation to generate one.
#' @param data_description A string in plain text or Markdown format, containing
#' a description of the data frame or any additional context that might be
#' helpful in understanding the data. This will be included in the system
#' prompt for the chat model. If a `system_prompt` argument is provided, the
#' `data_description` argument will be ignored.
#' @param extra_instructions A string in plain text or Markdown format, containing
#' any additional instructions for the chat model. These will be appended at
#' the end of the system prompt. If a `system_prompt` argument is provided,
#' the `extra_instructions` argument will be ignored.
#' @param create_chat_func A function that takes a system prompt and returns a
#' chat object. The default uses `ellmer::chat_openai()`.
#' @param ... Additional arguments passed to the `querychat_system_prompt()`
#' function, such as `data_description`, `extra_instructions`, and
#' `prompt_path`. If a `system_prompt` argument is provided, the
#' `...` arguments will be silently ignored.
#' @param system_prompt A string containing the system prompt for the chat model.
#' The default uses `querychat_system_prompt()` to generate a generic prompt,
#' which you can enhance by passing `data_description`, `extra_instructions`,
#' or `prompt_path` through `...`.
#' @param create_chat_func A function that takes a system prompt and returns a
#' chat object. The default uses `ellmer::chat_openai()`.
#'
#' @returns An object that can be passed to `querychat_server()` as the
#' `querychat_config` argument. By convention, this object should be named
Expand All @@ -34,45 +29,53 @@
#' @export
querychat_init <- function(
df,
tbl_name = deparse(substitute(df)),
...,
table_name = deparse(substitute(df)),
greeting = NULL,
data_description = NULL,
extra_instructions = NULL,
create_chat_func = purrr::partial(ellmer::chat_openai, model = "gpt-4o"),
system_prompt = querychat_system_prompt(
df,
tbl_name,
data_description = data_description,
extra_instructions = extra_instructions
)
table_name,
# By default, pass through any params supplied to querychat_init()
...
),
create_chat_func = purrr::partial(ellmer::chat_openai, model = "gpt-4o")
) {
is_tbl_name_ok <- is.character(tbl_name) &&
length(tbl_name) == 1 &&
grepl("^[a-zA-Z][a-zA-Z0-9_]*$", tbl_name, perl = TRUE)
if (!is_tbl_name_ok) {
if (missing(tbl_name)) {
is_table_name_ok <- is.character(table_name) &&
length(table_name) == 1 &&
grepl("^[a-zA-Z][a-zA-Z0-9_]*$", table_name, perl = TRUE)
if (!is_table_name_ok) {
if (missing(table_name)) {
rlang::abort(
"Unable to infer table name from `df` argument. Please specify `tbl_name` argument explicitly."
"Unable to infer table name from `df` argument. Please specify `table_name` argument explicitly."
)
} else {
rlang::abort(
"`tbl_name` argument must be a string containing a valid table name."
"`table_name` argument must be a string containing a valid table name."
)
}
}

force(df)
force(system_prompt)
force(system_prompt) # Have default `...` params evaluated
force(create_chat_func)

# TODO: Provide nicer looking errors here
stopifnot(
"df must be a data frame" = is.data.frame(df),
"tbl_name must be a string" = is.character(tbl_name),
"table_name must be a string" = is.character(table_name),
"system_prompt must be a string" = is.character(system_prompt),
"create_chat_func must be a function" = is.function(create_chat_func)
)

if ("table_name" %in% names(attributes(system_prompt))) {
# If available, make sure the `table_name` argument to `querychat_init()`
# matches the one supplied to the system prompt
if (table_name != attr(system_prompt, "table_name")) {
rlang::abort(
"`querychat_init(table_name=)` must match system prompt `table_name` supplied to `querychat_system_prompt()`."
)
}
}
if (!is.null(greeting)) {
greeting <- paste(collapse = "\n", greeting)
} else {
Expand All @@ -83,7 +86,7 @@ querychat_init <- function(
}

conn <- DBI::dbConnect(duckdb::duckdb(), dbdir = ":memory:")
duckdb::duckdb_register(conn, tbl_name, df, experimental = FALSE)
duckdb::duckdb_register(conn, table_name, df, experimental = FALSE)
shiny::onStop(function() DBI::dbDisconnect(conn))

structure(
Expand Down
29 changes: 29 additions & 0 deletions pkg-r/man/df_to_schema.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading