Skip to content

Commit

Permalink
Write test cases #19
Browse files Browse the repository at this point in the history
  • Loading branch information
hauselin committed Jul 29, 2024
1 parent 9410121 commit 14172b9
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 1 deletion.
4 changes: 3 additions & 1 deletion R/ollama.R
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,7 @@ embed <- function(model, input, truncate = TRUE, normalize = TRUE, keep_alive =
{
resp <- httr2::req_perform(req)
json_body <- httr2::resp_body_json(resp)$embeddings
# matrix
m <- do.call(cbind, lapply(json_body, function(x) {
v <- unlist(x)
if (normalize) {
Expand Down Expand Up @@ -584,7 +585,8 @@ embeddings <- function(model, prompt, normalize = TRUE, keep_alive = "5m", endpo
tryCatch(
{
resp <- httr2::req_perform(req)
v <- unlist(resp_process(resp, "jsonlist")$embedding)
# vector
v <- unlist(httr2::resp_body_json(resp)$embedding)
if (normalize) {
v <- normalize(v)
}
Expand Down
24 changes: 24 additions & 0 deletions tests/testthat/test-embeddings.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,28 @@ library(ollamar)

test_that("embeddings function works with basic input", {
skip_if_not(test_connection()$status_code == 200, "Ollama server not available")

result <- embeddings("all-minilm", "hello")
expect_type(result, "double")
expect_true(is.null(dim(result)[2])) # not matrix

# model options
expect_type(embeddings("all-minilm", "hello", temperature = 2), "double")
expect_error(embeddings("all-minilm", "hello", dfdsffds = 0))

# check normalize (default is normalize = TRUE)
result <- embeddings("all-minilm", "hello", normalize = TRUE)
result
expect_true(all.equal(1, vector_norm(result)))
result2 <- embeddings("all-minilm", "hello") # default is normalize = TRUE
expect_true(sum(result - result2) == 0) # result and result2 vectors should be the same

# check unormalize
result3 <- embeddings("all-minilm", "hello", normalize = FALSE)
expect_false(sum(result - result3) == 0) # result and result3 vectors are different

# cosine similarity
expect_true(all.equal(sum(result * result), 1))
expect_true(sum(result * result2) != 1)
expect_true(sum(result * result3) != 1)
})
9 changes: 9 additions & 0 deletions tests/testthat/test-options.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
library(testthat)
library(ollamar)

test_that("model options", {

expect_true(check_option_valid("mirostat"))
expect_false(check_option_valid("sdfadsfdf"))

})

0 comments on commit 14172b9

Please sign in to comment.