diff --git a/R/ollama.R b/R/ollama.R
index c380a71..3521c1e 100644
--- a/R/ollama.R
+++ b/R/ollama.R
@@ -64,7 +64,7 @@ create_request <- function(endpoint, host = NULL) {
 #' @param stream Enable response streaming. Default is FALSE.
 #' @param raw If TRUE, no formatting will be applied to the prompt. You may choose to use the raw parameter if you are specifying a full templated prompt in your request to the API. Default is FALSE.
 #' @param keep_alive The time to keep the connection alive. Default is "5m" (5 minutes).
-#' @param output A character vector of the output format. Default is "resp". Options are "resp", "jsonlist", "raw", "df", "text".
+#' @param output A character vector of the output format. Default is "resp". Options are "resp", "jsonlist", "raw", "df", "text", "req" (httr2_request object).
 #' @param endpoint The endpoint to generate the completion. Default is "/api/generate".
 #' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL.
 #' @param ... Additional options to pass to the model.
@@ -80,10 +80,10 @@ create_request <- function(endpoint, host = NULL) {
 #' generate("llama3", "The sky is...", stream = TRUE, output = "text")
 #' generate("llama3", "The sky is...", stream = TRUE, output = "text", temperature = 2.0)
 #' generate("llama3", "The sky is...", stream = FALSE, output = "jsonlist")
-generate <- function(model, prompt, suffix = "", images = "", system = "", template = "", context = list(), stream = FALSE, raw = FALSE, keep_alive = "5m", output = c("resp", "jsonlist", "raw", "df", "text"), endpoint = "/api/generate", host = NULL, ...) {
+generate <- function(model, prompt, suffix = "", images = "", system = "", template = "", context = list(), stream = FALSE, raw = FALSE, keep_alive = "5m", output = c("resp", "jsonlist", "raw", "df", "text", "req"), endpoint = "/api/generate", host = NULL, ...) {
   output <- output[1]
-  if (!output %in% c("df", "resp", "jsonlist", "raw", "text")) {
-    stop("Invalid output format specified. Supported formats: 'df', 'resp', 'jsonlist', 'raw', 'text'")
+  if (!output %in% c("df", "resp", "jsonlist", "raw", "text", "req")) {
+    stop("Invalid output format specified. Supported formats: 'df', 'resp', 'jsonlist', 'raw', 'text', 'req'")
   }
 
   req <- create_request(endpoint, host)
@@ -118,6 +118,10 @@ generate <- function(model, prompt, suffix = "", images = "", system = "", templ
   req <- httr2::req_body_json(req, body_json, stream = stream)
 
+  if (output == "req") {
+    return(req)
+  }
+
   if (!stream) {
     tryCatch(
       {
@@ -160,7 +164,7 @@ generate <- function(model, prompt, suffix = "", images = "", system = "", templ
 #' @param tools Tools for the model to use if supported. Requires stream = FALSE. Default is an empty list.
 #' @param stream Enable response streaming. Default is FALSE.
 #' @param keep_alive The duration to keep the connection alive. Default is "5m".
-#' @param output The output format. Default is "resp". Other options are "jsonlist", "raw", "df", "text".
+#' @param output The output format. Default is "resp". Other options are "jsonlist", "raw", "df", "text", "req" (httr2_request object).
 #' @param endpoint The endpoint to chat with the model. Default is "/api/chat".
 #' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL.
 #' @param ... Additional options to pass to the model.
@@ -191,9 +195,9 @@ generate <- function(model, prompt, suffix = "", images = "", system = "", templ
 #'   list(role = "user", content = "List all the previous messages.")
 #' )
 #' chat("llama3", messages, stream = TRUE)
-chat <- function(model, messages, tools = list(), stream = FALSE, keep_alive = "5m", output = c("resp", "jsonlist", "raw", "df", "text"), endpoint = "/api/chat", host = NULL, ...) {
+chat <- function(model, messages, tools = list(), stream = FALSE, keep_alive = "5m", output = c("resp", "jsonlist", "raw", "df", "text", "req"), endpoint = "/api/chat", host = NULL, ...) {
   output <- output[1]
-  if (!output %in% c("df", "resp", "jsonlist", "raw", "text")) {
-    stop("Invalid output format specified. Supported formats: 'df', 'resp', 'jsonlist', 'raw', 'text'")
+  if (!output %in% c("df", "resp", "jsonlist", "raw", "text", "req")) {
+    stop("Invalid output format specified. Supported formats: 'df', 'resp', 'jsonlist', 'raw', 'text', 'req'")
   }
 
@@ -218,6 +222,9 @@ chat <- function(model, messages, tools = list(), stream = FALSE, keep_alive = "
   }
 
   req <- httr2::req_body_json(req, body_json, stream = stream)
+  if (output == "req") {
+    return(req)
+  }
 
   if (!stream) {
     tryCatch(
diff --git a/README.Rmd b/README.Rmd
index ce5937f..54fa3fe 100644
--- a/README.Rmd
+++ b/README.Rmd
@@ -284,3 +284,87 @@ messages[[4]] # get 2nd message
 # delete a message at a specific index/position (2nd position in the example below)
 messages <- delete_message(messages, 2)
 ```
+
+## Advanced usage
+
+### Parallel requests
+
+For the `generate()` and `chat()` endpoints/functions, you can make parallel requests with the `req_perform_parallel` function from the `httr2` library. You need to specify `output = 'req'` in the function call so that these functions return `httr2_request` objects instead of `httr2_response` objects.
+
+```{r eval=FALSE}
+library(httr2)
+
+prompt <- "Tell me a 5-word story"
+
+# create 5 httr2_request objects that generate a response to the same prompt
+reqs <- lapply(1:5, function(r) generate("llama3.1", prompt, output = "req"))
+
+# make parallel requests and get responses
+resps <- req_perform_parallel(reqs)  # list of httr2_response objects
+
+# process the responses
+sapply(resps, resp_process, "text")  # get responses as text
+# [1] "She found him in Paris."         "She found the key upstairs."
+# [3] "She found her long-lost sister." "She found love on Mars."
+# [5] "She found the diamond ring."
+
+```
+
+Example: sentiment analysis with parallel requests using the `generate()` function
+
+```{r eval=FALSE}
+library(httr2)
+library(glue)
+library(dplyr)
+
+# texts to classify
+texts <- c('I love this product', 'I hate this product', 'I am neutral about this product')
+
+# create httr2_request objects for each text, using the same prompt template
+reqs <- lapply(texts, function(text) {
+  prompt <- glue("Your only task/role is to evaluate the sentiment of product reviews, and your response should be one of the following: 'positive', 'negative', or 'other'. Product review: {text}")
+  generate("llama3.1", prompt, output = "req")
+})
+
+# make parallel requests and get responses
+resps <- req_perform_parallel(reqs)  # list of httr2_response objects
+
+# process the responses
+sapply(resps, resp_process, "text")  # get responses as text
+# [1] "Positive"                            "Negative."
+# [3] "'neutral' translates to... 'other'."
+
+
+```
+
+Example: sentiment analysis with parallel requests using the `chat()` function
+
+```{r eval=FALSE}
+library(httr2)
+library(dplyr)
+
+# texts to classify
+texts <- c('I love this product', 'I hate this product', 'I am neutral about this product')
+
+# create system prompt
+chat_history <- create_message("Your only task/role is to evaluate the sentiment of product reviews provided by the user. Your response should simply be 'positive', 'negative', or 'other'.", "system")
+
+# create httr2_request objects for each text, using the same system prompt
+reqs <- lapply(texts, function(text) {
+  messages <- append_message(text, "user", chat_history)
+  chat("llama3.1", messages, output = "req")
+})
+
+# make parallel requests and get responses
+resps <- req_perform_parallel(reqs)  # list of httr2_response objects
+
+# process the responses
+bind_rows(lapply(resps, resp_process, "df"))  # get responses as dataframes
+# # A tibble: 3 × 4
+#   model    role      content  created_at
+#   <chr>    <chr>     <chr>    <chr>
+# 1 llama3.1 assistant Positive 2024-08-05T17:54:27.758618Z
+# 2 llama3.1 assistant negative 2024-08-05T17:54:27.657525Z
+# 3 llama3.1 assistant other    2024-08-05T17:54:27.657067Z
+
+```
diff --git a/README.md b/README.md
index 4f8d694..45c69f9 100644
--- a/README.md
+++ b/README.md
@@ -323,3 +323,88 @@ messages[[4]] # get 2nd message
 # delete a message at a specific index/position (2nd position in the example below)
 messages <- delete_message(messages, 2)
 ```
+
+## Advanced usage
+
+### Parallel requests
+
+For the `generate()` and `chat()` endpoints/functions, you can make
+parallel requests with the `req_perform_parallel` function from the
+`httr2` library. You need to specify `output = 'req'` in the function
+call so that these functions return `httr2_request` objects instead of
+`httr2_response` objects.
+
+``` r
+library(httr2)
+
+prompt <- "Tell me a 5-word story"
+
+# create 5 httr2_request objects that generate a response to the same prompt
+reqs <- lapply(1:5, function(r) generate("llama3.1", prompt, output = "req"))
+
+# make parallel requests and get responses
+resps <- req_perform_parallel(reqs)  # list of httr2_response objects
+
+# process the responses
+sapply(resps, resp_process, "text")  # get responses as text
+# [1] "She found him in Paris."         "She found the key upstairs."
+# [3] "She found her long-lost sister." "She found love on Mars."
+# [5] "She found the diamond ring."
+```
+
+Example: sentiment analysis with parallel requests using the `generate()`
+function
+
+``` r
+library(httr2)
+library(glue)
+library(dplyr)
+
+# texts to classify
+texts <- c('I love this product', 'I hate this product', 'I am neutral about this product')
+
+# create httr2_request objects for each text, using the same prompt template
+reqs <- lapply(texts, function(text) {
+  prompt <- glue("Your only task/role is to evaluate the sentiment of product reviews, and your response should be one of the following: 'positive', 'negative', or 'other'. Product review: {text}")
+  generate("llama3.1", prompt, output = "req")
+})
+
+# make parallel requests and get responses
+resps <- req_perform_parallel(reqs)  # list of httr2_response objects
+
+# process the responses
+sapply(resps, resp_process, "text")  # get responses as text
+# [1] "Positive"                            "Negative."
+# [3] "'neutral' translates to... 'other'."
+```
+
+Example: sentiment analysis with parallel requests using the `chat()` function
+
+``` r
+library(httr2)
+library(dplyr)
+
+# texts to classify
+texts <- c('I love this product', 'I hate this product', 'I am neutral about this product')
+
+# create system prompt
+chat_history <- create_message("Your only task/role is to evaluate the sentiment of product reviews provided by the user. Your response should simply be 'positive', 'negative', or 'other'.", "system")
+
+# create httr2_request objects for each text, using the same system prompt
+reqs <- lapply(texts, function(text) {
+  messages <- append_message(text, "user", chat_history)
+  chat("llama3.1", messages, output = "req")
+})
+
+# make parallel requests and get responses
+resps <- req_perform_parallel(reqs)  # list of httr2_response objects
+
+# process the responses
+bind_rows(lapply(resps, resp_process, "df"))  # get responses as dataframes
+# # A tibble: 3 × 4
+#   model    role      content  created_at
+#   <chr>    <chr>     <chr>    <chr>
+# 1 llama3.1 assistant Positive 2024-08-05T17:54:27.758618Z
+# 2 llama3.1 assistant negative 2024-08-05T17:54:27.657525Z
+# 3 llama3.1 assistant other    2024-08-05T17:54:27.657067Z
+```
diff --git a/man/chat.Rd b/man/chat.Rd
index 9319007..4b1c925 100644
--- a/man/chat.Rd
+++ b/man/chat.Rd
@@ -10,7 +10,7 @@ chat(
   tools = list(),
   stream = FALSE,
   keep_alive = "5m",
-  output = c("resp", "jsonlist", "raw", "df", "text"),
+  output = c("resp", "jsonlist", "raw", "df", "text", "req"),
   endpoint = "/api/chat",
   host = NULL,
   ...
@@ -27,7 +27,7 @@ chat(
 
 \item{keep_alive}{The duration to keep the connection alive. Default is "5m".}
 
-\item{output}{The output format. Default is "resp". Other options are "jsonlist", "raw", "df", "text".}
+\item{output}{The output format. Default is "resp". Other options are "jsonlist", "raw", "df", "text", "req" (httr2_request object).}
 
 \item{endpoint}{The endpoint to chat with the model. Default is "/api/chat".}
 
diff --git a/man/generate.Rd b/man/generate.Rd
index f166350..e0a0af8 100644
--- a/man/generate.Rd
+++ b/man/generate.Rd
@@ -15,7 +15,7 @@ generate(
   stream = FALSE,
   raw = FALSE,
   keep_alive = "5m",
-  output = c("resp", "jsonlist", "raw", "df", "text"),
+  output = c("resp", "jsonlist", "raw", "df", "text", "req"),
   endpoint = "/api/generate",
   host = NULL,
   ...
@@ -42,7 +42,7 @@ generate(
 
 \item{keep_alive}{The time to keep the connection alive. Default is "5m" (5 minutes).}
 
-\item{output}{A character vector of the output format. Default is "resp". Options are "resp", "jsonlist", "raw", "df", "text".}
+\item{output}{A character vector of the output format. Default is "resp". Options are "resp", "jsonlist", "raw", "df", "text", "req" (httr2_request object).}
 
 \item{endpoint}{The endpoint to generate the completion. Default is "/api/generate".}
 
diff --git a/tests/testthat/test-chat.R b/tests/testthat/test-chat.R
index e4fd9cd..1f69847 100644
--- a/tests/testthat/test-chat.R
+++ b/tests/testthat/test-chat.R
@@ -11,6 +11,8 @@ test_that("chat function works with basic input", {
   # incorrect output type
   expect_error(chat("llama3", messages, output = "abc"))
 
+  expect_s3_class(chat("llama3.1", messages, output = "req"), "httr2_request")
+
   # not streaming
   expect_s3_class(chat("llama3", messages), "httr2_response")
   expect_s3_class(chat("llama3", messages, output = "resp"), "httr2_response")
diff --git a/tests/testthat/test-generate.R b/tests/testthat/test-generate.R
index dd1cddd..fff5d4c 100644
--- a/tests/testthat/test-generate.R
+++ b/tests/testthat/test-generate.R
@@ -7,6 +7,8 @@ test_that("generate function works with different outputs and resp_process", {
   # incorrect output type
   expect_error(generate("llama3", "The sky is...", output = "abc"))
 
+  expect_s3_class(generate("llama3.1", "tell me a 5-word story", output = "req"), "httr2_request")
+
   # not streaming
   expect_s3_class(generate("llama3", "The sky is..."), "httr2_response")
   expect_s3_class(generate("llama3", "The sky is...", output = "resp"), "httr2_response")