
Commit

Add parallel requests
hauselin committed Aug 5, 2024
1 parent bab3300 commit 9f5c13b
Showing 7 changed files with 191 additions and 11 deletions.
21 changes: 14 additions & 7 deletions R/ollama.R
@@ -64,7 +64,7 @@ create_request <- function(endpoint, host = NULL) {
#' @param stream Enable response streaming. Default is FALSE.
#' @param raw If TRUE, no formatting will be applied to the prompt. You may choose to use the raw parameter if you are specifying a full templated prompt in your request to the API. Default is FALSE.
#' @param keep_alive The time to keep the connection alive. Default is "5m" (5 minutes).
#' @param output A character vector of the output format. Default is "resp". Options are "resp", "jsonlist", "raw", "df", "text".
#' @param output A character vector of the output format. Default is "resp". Options are "resp", "jsonlist", "raw", "df", "text", "req" (httr2_request object).
#' @param endpoint The endpoint to generate the completion. Default is "/api/generate".
#' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL.
#' @param ... Additional options to pass to the model.
@@ -80,10 +80,10 @@ create_request <- function(endpoint, host = NULL) {
#' generate("llama3", "The sky is...", stream = TRUE, output = "text")
#' generate("llama3", "The sky is...", stream = TRUE, output = "text", temperature = 2.0)
#' generate("llama3", "The sky is...", stream = FALSE, output = "jsonlist")
generate <- function(model, prompt, suffix = "", images = "", system = "", template = "", context = list(), stream = FALSE, raw = FALSE, keep_alive = "5m", output = c("resp", "jsonlist", "raw", "df", "text"), endpoint = "/api/generate", host = NULL, ...) {
generate <- function(model, prompt, suffix = "", images = "", system = "", template = "", context = list(), stream = FALSE, raw = FALSE, keep_alive = "5m", output = c("resp", "jsonlist", "raw", "df", "text", "req"), endpoint = "/api/generate", host = NULL, ...) {
output <- output[1]
if (!output %in% c("df", "resp", "jsonlist", "raw", "text")) {
stop("Invalid output format specified. Supported formats: 'df', 'resp', 'jsonlist', 'raw', 'text'")
if (!output %in% c("df", "resp", "jsonlist", "raw", "text", "req")) {
stop("Invalid output format specified. Supported formats: 'df', 'resp', 'jsonlist', 'raw', 'text', 'req'")
}

req <- create_request(endpoint, host)
@@ -118,6 +118,10 @@ generate <- function(model, prompt, suffix = "", images = "", system = "", templ

req <- httr2::req_body_json(req, body_json, stream = stream)

if (output == "req") {
return(req)
}

if (!stream) {
tryCatch(
{
@@ -160,7 +164,7 @@ generate <- function(model, prompt, suffix = "", images = "", system = "", templ
#' @param tools Tools for the model to use if supported. Requires stream = FALSE. Default is an empty list.
#' @param stream Enable response streaming. Default is FALSE.
#' @param keep_alive The duration to keep the connection alive. Default is "5m".
#' @param output The output format. Default is "resp". Other options are "jsonlist", "raw", "df", "text".
#' @param output The output format. Default is "resp". Other options are "jsonlist", "raw", "df", "text", "req" (httr2_request object).
#' @param endpoint The endpoint to chat with the model. Default is "/api/chat".
#' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL.
#' @param ... Additional options to pass to the model.
@@ -191,9 +195,9 @@ generate <- function(model, prompt, suffix = "", images = "", system = "", templ
#' list(role = "user", content = "List all the previous messages.")
#' )
#' chat("llama3", messages, stream = TRUE)
chat <- function(model, messages, tools = list(), stream = FALSE, keep_alive = "5m", output = c("resp", "jsonlist", "raw", "df", "text"), endpoint = "/api/chat", host = NULL, ...) {
chat <- function(model, messages, tools = list(), stream = FALSE, keep_alive = "5m", output = c("resp", "jsonlist", "raw", "df", "text", "req"), endpoint = "/api/chat", host = NULL, ...) {
output <- output[1]
if (!output %in% c("df", "resp", "jsonlist", "raw", "text")) {
if (!output %in% c("df", "resp", "jsonlist", "raw", "text", "req")) {
stop("Invalid output format specified. Supported formats: 'df', 'resp', 'jsonlist', 'raw', 'text', 'req'")
}

@@ -218,6 +222,9 @@ chat <- function(model, messages, tools = list(), stream = FALSE, keep_alive = "
}

req <- httr2::req_body_json(req, body_json, stream = stream)
if (output == "req") {
return(req)
}

if (!stream) {
tryCatch(
84 changes: 84 additions & 0 deletions README.Rmd
@@ -284,3 +284,87 @@ messages[[4]] # get 2nd message
# delete a message at a specific index/position (2nd position in the example below)
messages <- delete_message(messages, 2)
```

## Advanced usage

### Parallel requests

For the `generate()` and `chat()` functions (endpoints), you can make parallel requests with the `req_perform_parallel()` function from the `httr2` package. Specify `output = 'req'` so that these functions return `httr2_request` objects instead of `httr2_response` objects.

```{r eval=FALSE}
library(httr2)
prompt <- "Tell me a 5-word story"
# create 5 httr2_request objects that generate a response to the same prompt
reqs <- lapply(1:5, function(r) generate("llama3.1", prompt, output = "req"))
# make parallel requests and get the responses
resps <- req_perform_parallel(reqs) # list of httr2_response objects
# process the responses
sapply(resps, resp_process, "text") # get responses as text
# [1] "She found him in Paris." "She found the key upstairs."
# [3] "She found her long-lost sister." "She found love on Mars."
# [5] "She found the diamond ring."
```
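
If some requests fail (for example, when the server is busy or a model is not available), you can tell `httr2` to keep going and then work with only the successful responses. The chunk below is a minimal sketch, assuming `httr2` >= 1.0.0, where `req_perform_parallel()` has an `on_error` argument and the `resps_successes()` helper is available.

```{r eval=FALSE}
library(httr2)
prompt <- "Tell me a 5-word story"
reqs <- lapply(1:5, function(r) generate("llama3.1", prompt, output = "req"))
# keep going if an individual request errors instead of aborting the whole batch
resps <- req_perform_parallel(reqs, on_error = "continue")
# keep only the successful responses before processing them
resps_ok <- resps_successes(resps)
sapply(resps_ok, resp_process, "text")
```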

Example: sentiment analysis with parallel requests using the `generate()` function

```{r eval=FALSE}
library(httr2)
library(glue)
library(dplyr)
# text to classify
texts <- c('I love this product', 'I hate this product', 'I am neutral about this product')
# create httr2_request objects for each text, using the same system prompt
reqs <- lapply(texts, function(text) {
prompt <- glue("Your only task/role is to evaluate the sentiment of product reviews, and your response should be one of the following: 'positive', 'negative', or 'other'. Product review: {text}")
generate("llama3.1", prompt, output = "req")
})
# make parallel requests and get the responses
resps <- req_perform_parallel(reqs) # list of httr2_response objects
# process the responses
sapply(resps, resp_process, "text") # get responses as text
# [1] "Positive" "Negative."
# [3] "'neutral' translates to... 'other'."
```
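
Because `req_perform_parallel()` returns responses in the same order as the requests, you can pair each review with its label directly. A minimal sketch building on the chunk above (labels vary from run to run and may need cleaning before analysis):

```{r eval=FALSE}
# pair each review with the label returned by the model
labels <- sapply(resps, resp_process, "text")
data.frame(text = texts, sentiment = labels)
```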

Example: sentiment analysis with parallel requests using the `chat()` function

```{r eval=FALSE}
library(httr2)
library(dplyr)
# text to classify
texts <- c('I love this product', 'I hate this product', 'I am neutral about this product')
# create system prompt
chat_history <- create_message("Your only task/role is to evaluate the sentiment of product reviews provided by the user. Your response should simply be 'positive', 'negative', or 'other'.", "system")
# create httr2_request objects for each text, using the same system prompt
reqs <- lapply(texts, function(text) {
messages <- append_message(text, "user", chat_history)
chat("llama3.1", messages, output = "req")
})
# make parallel requests and get the responses
resps <- req_perform_parallel(reqs) # list of httr2_response objects
# process the responses
bind_rows(lapply(resps, resp_process, "df")) # get responses as data frames
# # A tibble: 3 × 4
# model role content created_at
# <chr> <chr> <chr> <chr>
# 1 llama3.1 assistant Positive 2024-08-05T17:54:27.758618Z
# 2 llama3.1 assistant negative 2024-08-05T17:54:27.657525Z
# 3 llama3.1 assistant other 2024-08-05T17:54:27.657067Z
```
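
Because `output = 'req'` returns ordinary `httr2_request` objects, other `httr2` tools work on them too. As a rough sketch (assuming `httr2` >= 1.0.0 for `req_perform_sequential()`), you can preview a request with `req_dry_run()` before sending anything, or perform the requests one at a time instead of in parallel:

```{r eval=FALSE}
library(httr2)
# preview the first request (method, headers, JSON body) without contacting the server
req_dry_run(reqs[[1]])
# or send the requests one at a time instead of in parallel
resps <- req_perform_sequential(reqs)
bind_rows(lapply(resps, resp_process, "df"))
```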
85 changes: 85 additions & 0 deletions README.md
@@ -323,3 +323,88 @@ messages[[4]] # get 2nd message
# delete a message at a specific index/position (2nd position in the example below)
messages <- delete_message(messages, 2)
```

## Advanced usage

### Parallel requests

For the `generate()` and `chat()` functions (endpoints), you can make
parallel requests with the `req_perform_parallel()` function from the
`httr2` package. Specify `output = 'req'` so that these functions
return `httr2_request` objects instead of `httr2_response` objects.

``` r
library(httr2)

prompt <- "Tell me a 5-word story"

# create 5 httr2_request objects that generate a response to the same prompt
reqs <- lapply(1:5, function(r) generate("llama3.1", prompt, output = "req"))

# make parallel requests and get the responses
resps <- req_perform_parallel(reqs) # list of httr2_response objects

# process the responses
sapply(resps, resp_process, "text") # get responses as text
# [1] "She found him in Paris." "She found the key upstairs."
# [3] "She found her long-lost sister." "She found love on Mars."
# [5] "She found the diamond ring."
```
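
By default, `httr2` may open several connections to the same host at once. If you want to limit how many requests hit the Ollama server simultaneously, you can supply a curl pool with a smaller per-host connection limit. This sketch reuses the `reqs` created above and assumes `httr2` 1.0.x, where `req_perform_parallel()` accepts a `pool` argument; newer releases may expose this setting differently, so check your version's documentation.

``` r
library(httr2)
# allow at most 2 concurrent connections to the Ollama host
pool <- curl::new_pool(host_con = 2)
resps <- req_perform_parallel(reqs, pool = pool)
sapply(resps, resp_process, "text")
```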

Example: sentiment analysis with parallel requests using the
`generate()` function

``` r
library(httr2)
library(glue)
library(dplyr)

# text to classify
texts <- c('I love this product', 'I hate this product', 'I am neutral about this product')

# create httr2_request objects for each text, using the same system prompt
reqs <- lapply(texts, function(text) {
prompt <- glue("Your only task/role is to evaluate the sentiment of product reviews, and your response should be one of the following: 'positive', 'negative', or 'other'. Product review: {text}")
generate("llama3.1", prompt, output = "req")
})

# make parallel requests and get the responses
resps <- req_perform_parallel(reqs) # list of httr2_response objects

# process the responses
sapply(resps, resp_process, "text") # get responses as text
# [1] "Positive" "Negative."
# [3] "'neutral' translates to... 'other'."
```

Example: sentiment analysis with parallel requests using the `chat()` function

``` r
library(httr2)
library(dplyr)

# text to classify
texts <- c('I love this product', 'I hate this product', 'I am neutral about this product')

# create system prompt
chat_history <- create_message("Your only task/role is to evaluate the sentiment of product reviews provided by the user. Your response should simply be 'positive', 'negative', or 'other'.", "system")

# create httr2_request objects for each text, using the same system prompt
reqs <- lapply(texts, function(text) {
messages <- append_message(text, "user", chat_history)
chat("llama3.1", messages, output = "req")
})

# make parallel requests and get the responses
resps <- req_perform_parallel(reqs) # list of httr2_response objects

# process the responses
bind_rows(lapply(resps, resp_process, "df")) # get responses as data frames
# # A tibble: 3 × 4
# model role content created_at
# <chr> <chr> <chr> <chr>
# 1 llama3.1 assistant Positive 2024-08-05T17:54:27.758618Z
# 2 llama3.1 assistant negative 2024-08-05T17:54:27.657525Z
# 3 llama3.1 assistant other 2024-08-05T17:54:27.657067Z
```
4 changes: 2 additions & 2 deletions man/chat.Rd

(Generated file; diff not rendered.)

4 changes: 2 additions & 2 deletions man/generate.Rd

(Generated file; diff not rendered.)

2 changes: 2 additions & 0 deletions tests/testthat/test-chat.R
@@ -11,6 +11,8 @@ test_that("chat function works with basic input", {
# incorrect output type
expect_error(chat("llama3", messages, output = "abc"))

expect_s3_class(chat("llama3.1", messages, output = "req"), "httr2_request")

# not streaming
expect_s3_class(chat("llama3", messages), "httr2_response")
expect_s3_class(chat("llama3", messages, output = "resp"), "httr2_response")
2 changes: 2 additions & 0 deletions tests/testthat/test-generate.R
@@ -7,6 +7,8 @@ test_that("generate function works with different outputs and resp_process", {
# incorrect output type
expect_error(generate("llama3", "The sky is...", output = "abc"))

expect_s3_class(generate("llama3.1", "tell me a 5-word story", output = "req"), "httr2_request")

# not streaming
expect_s3_class(generate("llama3", "The sky is..."), "httr2_response")
expect_s3_class(generate("llama3", "The sky is...", output = "resp"), "httr2_response")
