Skip to content

Commit

Permalink
Update docs for chat and generate #8
Browse files Browse the repository at this point in the history
  • Loading branch information
hauselin committed Aug 17, 2024
1 parent 81041ac commit fe88b38
Show file tree
Hide file tree
Showing 12 changed files with 260 additions and 6 deletions.
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ export(delete)
export(delete_message)
export(embed)
export(embeddings)
export(encode_images_in_messages)
export(generate)
export(image_encode_base64)
export(insert_message)
Expand All @@ -26,6 +27,8 @@ export(resp_process)
export(search_options)
export(show)
export(test_connection)
export(validate_message)
export(validate_messages)
export(validate_options)
importFrom(crayon,green)
importFrom(crayon,red)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# ollamar (development version)

- `generate()` and `chat()` accept multiple images as prompts/messages.
- Add functions to validate messages for `chat()` function: `validate_message()`, `validate_messages()`.
- Add `encode_images_in_messages()` to encode images in messages for `chat()` function.

# ollamar 1.2.0

Expand Down
13 changes: 13 additions & 0 deletions R/ollama.R
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,13 @@ generate <- function(model, prompt, suffix = "", images = "", system = "", templ
#' list(role = "user", content = "List all the previous messages.")
#' )
#' chat("llama3", messages, stream = TRUE)
#'
#' # image
#' image_path <- file.path(system.file("extdata", package = "ollamar"), "image1.png")
#' messages <- list(
#' list(role = "user", content = "What is in the image?", images = image_path)
#' )
#' chat("benzie/llava-phi-3", messages, output = 'text')
chat <- function(model, messages, tools = list(), stream = FALSE, keep_alive = "5m", output = c("resp", "jsonlist", "raw", "df", "text", "req"), endpoint = "/api/chat", host = NULL, ...) {
output <- output[1]
if (!output %in% c("df", "resp", "jsonlist", "raw", "text", "req")) {
Expand All @@ -210,6 +217,12 @@ chat <- function(model, messages, tools = list(), stream = FALSE, keep_alive = "
req <- create_request(endpoint, host)
req <- httr2::req_method(req, "POST")

if (!validate_messages(messages)) {
stop("Invalid messages.")
}

messages <- encode_images_in_messages(messages)

body_json <- list(
model = model,
messages = messages,
Expand Down
99 changes: 99 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -513,3 +513,102 @@ delete_message <- function(x, position = -1) {
if (position < 0) position <- length(x) + position + 1
return(x[-position])
}




#' Validate a message
#'
#' Validate a message to ensure it has the required fields and the correct data types for the `chat()` function.
#' @param message A list with a single message of list class.
#'
#' @return TRUE if message is valid, otherwise an error is thrown.
#' @export
#'
#' @examples
#' validate_message(list(role = "user", content = "Hello"))
validate_message <- function(message) {
if (!is.list(message)) {
stop("Message must be list.")
}
if (!all(c("role", "content") %in% names(message))) {
stop("Message must have role and content.")
}
if (!is.character(message$role)) {
stop("Message role must be character.")
}
if (!is.character(message$content)) {
stop("Message content must be character.")
}
return(TRUE)
}



#' Validate a list of messages
#'
#' Validate a list of messages to ensure they have the required fields and the correct data types for the `chat()` function.
#'
#' @param messages A list of messages, each of list class.
#'
#' @return TRUE if all messages are valid, otherwise warning messages are printed and FALSE is returned.
#' @export
#'
#' @examples
#' validate_messages(list(
#' list(role = "system", content = "Be friendly"),
#' list(role = "user", content = "Hello")
#' ))
validate_messages <- function(messages) {
status <- TRUE
for (i in 1:length(messages)) {
tryCatch({
validate_message(messages[[i]])
}, error = function(e) {
status <<- FALSE
message(paste0("Message ", i, ": ", conditionMessage(e)))
})
}
return(status)
}



#' Encode images in messages to base64 format
#'
#' @param messages A list of messages, each of list class. Generally used in the `chat()` function.
#'
#' @return A list of messages with images encoded in base64 format.
#' @export
#'
#' @examples
#' image <- file.path(system.file("extdata", package = "ollamar"), "image1.png")
#' messages <- list(
#' list(role = "user", content = "what is in the image?", images = image)
#' )
#' messages_updated <- encode_images_in_messages(messages)
encode_images_in_messages <- function(messages) {
if (!validate_messages(messages)) {
stop("Invalid messages.")
}

for (i in 1:length(messages)) {
message <- messages[[i]]
if ("images" %in% names(message)) {
images <- message$images
if (images[1] != "") {
message$images <- lapply(images, image_encode_base64)
messages[[i]] <- message
} else {
next
}
}
}

# revalidate messages
if (!validate_messages(messages)) {
stop("Invalid messages.")
}

return(messages)
}
3 changes: 3 additions & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ reference:
- prepend_message
- delete_message
- insert_message
- validate_message
- validate_messages
- encode_images_in_messages

- subtitle: Model options
desc: Functions to get information about the options available.
Expand Down
7 changes: 7 additions & 0 deletions man/chat.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 24 additions & 0 deletions man/encode_images_in_messages.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 20 additions & 0 deletions man/validate_message.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 23 additions & 0 deletions man/validate_messages.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 29 additions & 0 deletions tests/testthat/test-chat.R
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,32 @@ test_that("chat function handles additional options", {
expect_type(result_creative, "character")
expect_error(chat("llama3", messages, output = "text", abc = 2.0))
})


test_that("chat function handles images in messages", {
skip_if_not(test_connection()$status_code == 200, "Ollama server not available")
skip_if_not(model_avail("benzie/llava-phi-3"), "benzie/llava-phi-3 model not available")

images <- c(file.path(system.file("extdata", package = "ollamar"), "image1.png"),
file.path(system.file("extdata", package = "ollamar"), "image2.png"))

# 1 image
messages <- list(
list(role = "system", content = "You have to evaluate what objects are in images."),
list(role = "user", content = "what is in the image?", images = images[2])
)

result <- chat("benzie/llava-phi-3", messages, output = "text")
expect_match(tolower(result), "cam")

# multiple images
messages <- list(
list(role = "system", content = "You have to evaluate what objects are in the two images."),
list(role = "user", content = "what objects are in the two separate images?", images = images)
)

result <- chat("benzie/llava-phi-3", messages, output = "text")
expect_type(result, "character")
expect_true(grepl("melon", tolower(result)) | grepl("cam", tolower(result)))

})
13 changes: 7 additions & 6 deletions tests/testthat/test-generate.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,27 +53,28 @@ test_that("generate function works with additional options", {
})


# Note for the following test to work you need to make sure the "benzie/llava-phi-3:latest" model exists locally

test_that("generate function works with images", {
skip_if_not(test_connection()$status_code == 200, "Ollama server not available")
skip_if_not(model_avail("benzie/llava-phi-3"), "benzie/llava-phi-3 model not available")

image_path <- file.path(system.file("extdata", package = "ollamar"), "image1.png")

result <- generate("benzie/llava-phi-3:latest", "What is in the image?", images = image_path)
result <- generate("benzie/llava-phi-3", "What is in the image?", images = image_path)
expect_s3_class(result, "httr2_response")
expect_type(resp_process(result, "text"), "character")
expect_match(tolower(resp_process(result, "text")), "watermelon")

expect_error(generate("benzie/llava-phi-3:latest", "What is in the image?", images = "incorrect_path.png"))
expect_error(generate("benzie/llava-phi-3", "What is in the image?", images = "incorrect_path.png"))

images <- c(file.path(system.file("extdata", package = "ollamar"), "image1.png"),
file.path(system.file("extdata", package = "ollamar"), "image2.png"))

# multiple images
result <- generate("benzie/llava-phi-3:latest", "What objects are in the two images?", images = images)
expect_s3_class(result, "httr2_response")
expect_type(resp_process(result, "text"), "character")
result <- generate("benzie/llava-phi-3", "What objects are in the two images?",
images = images, output = 'text')
expect_type(result, "character")
expect_true(grepl("melon", tolower(result)) | grepl("cam", tolower(result)))

})

Expand Down
30 changes: 30 additions & 0 deletions tests/testthat/test-utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,34 @@ test_that("copy function works with basic input", {
expect_true(msg5[[1]]$role == "2" & msg5[[2]]$role == "2.1" & msg5[[3]]$role == "4")
expect_true(msg5[[1]]$content == "hello2" & msg5[[2]]$content == "hello2.1" & msg5[[3]]$content == "hello4")


expect_true(validate_message(list(role = "user", content = "hello")))
expect_error(validate_message(""))
expect_error(validate_message(list(role = "user")))
expect_error(validate_message(list(content = "hello")))
expect_error(validate_message(list(role = 1, content = "hello")))
expect_error(validate_message(list(role = "user", content = 1)))


expect_true(validate_messages(list(
list(role = "user", content = "hello")
)))
expect_true(validate_messages(list(
list(role = "system", content = "hello"),
list(role = "user", content = "hello")
)))
expect_false(validate_messages(list(
list(role = "system", content = "hello"),
list(role = "user", content = 1)
)))

images <- c(file.path(system.file("extdata", package = "ollamar"), "image1.png"),
file.path(system.file("extdata", package = "ollamar"), "image2.png"))

expect_type(encode_images_in_messages(list(
list(role = "user", content = "hello", images = images[1]),
list(role = "user", content = "hello"),
list(role = "user", content = "hello", images = "")
)), "list")

})

0 comments on commit fe88b38

Please sign in to comment.