From d9790fc40235cf55530b8c8c476e9878f45d4ad4 Mon Sep 17 00:00:00 2001 From: BERENZ Date: Tue, 7 May 2024 08:12:13 +0200 Subject: [PATCH] changed the defaults for nnd algorithm and other changes --- DESCRIPTION | 1 + NAMESPACE | 2 +- R/blocking.R | 22 +++++++++++----------- R/controls.R | 2 +- R/method_nnd.R | 20 ++++++++++---------- R/methods.R | 2 +- README.md | 22 ++++++++++++---------- inst/tinytest/test_annoy.R | 2 ++ inst/tinytest/test_blocking.R | 2 +- inst/tinytest/test_hnsw.R | 2 ++ inst/tinytest/test_mlpack.R | 4 ++++ inst/tinytest/test_reclin2.R | 2 +- inst/tinytest/test_true_blocks.R | 0 man/blocking.Rd | 1 + man/controls_ann.Rd | 2 +- vignettes/v2-reclin.Rmd | 10 +++++----- 16 files changed, 54 insertions(+), 42 deletions(-) create mode 100644 inst/tinytest/test_true_blocks.R diff --git a/DESCRIPTION b/DESCRIPTION index a85a1a4..8ff08af 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -25,6 +25,7 @@ Imports: rnndescent, igraph, data.table, + RcppAlgos, methods Suggests: tinytest, diff --git a/NAMESPACE b/NAMESPACE index d744f92..3af2e01 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -6,6 +6,7 @@ export(controls_ann) export(controls_txt) export(pair_ann) import(data.table) +importFrom(RcppAlgos,comboGeneral) importFrom(RcppAnnoy,AnnoyAngular) importFrom(RcppAnnoy,AnnoyEuclidean) importFrom(RcppAnnoy,AnnoyHamming) @@ -31,6 +32,5 @@ importFrom(text2vec,create_vocabulary) importFrom(text2vec,itoken) importFrom(text2vec,itoken_parallel) importFrom(text2vec,vocab_vectorizer) -importFrom(utils,combn) importFrom(utils,setTxtProgressBar) importFrom(utils,txtProgressBar) diff --git a/R/blocking.R b/R/blocking.R index 6ba6609..4883cb1 100644 --- a/R/blocking.R +++ b/R/blocking.R @@ -9,7 +9,7 @@ #' @importFrom igraph graph_from_data_frame #' @importFrom igraph make_clusters #' @importFrom igraph compare -#' @importFrom utils combn +#' @importFrom RcppAlgos comboGeneral #' #' #' @title Block records based on text data. @@ -41,6 +41,7 @@ #' \itemize{ #' \item{\code{result} -- \code{data.table} with indices (rows) of x, y, block and distance between points} #' \item{\code{method} -- name of the ANN algorithm used,} +#' \item{\code{deduplication} -- information whether deduplication was applied,} #' \item{\code{metrics} -- metrics for quality assessment, if \code{true_blocks} is provided,} #' \item{\code{colnames} -- variable names (colnames) used for search,} #' \item{\code{graph} -- \code{igraph} class object.} @@ -96,12 +97,11 @@ blocking <- function(x, is.character(x) | is.matrix(x) | inherits(x, "Matrix")) ## assuming rows (for nnd) - stopifnot("Minimum 3 cases required for x" = NROW(x) > 2) - - if (!is.null(y)) { - stopifnot("Minimum 3 cases required for y" = NROW(y) > 2) - } - + # stopifnot("Minimum 3 cases required for x" = NROW(x) > 2) + # + # if (!is.null(y)) { + # stopifnot("Minimum 3 cases required for y" = NROW(y) > 2) + # } if (!is.null(ann_write)) { stopifnot("Path provided in the `ann_write` is incorrect" = file.exists(ann_write) ) @@ -314,11 +314,10 @@ blocking <- function(x, } - #consider using RcppAlgos::comboGeneral(nrow(pairs_to_eval_long), 2, nThreads=n_threads) - candidate_pairs <- utils::combn(nrow(pairs_to_eval_long), 2) + candidate_pairs <- RcppAlgos::comboGeneral(nrow(pairs_to_eval_long), 2, nThreads=n_threads) - same_block <- pairs_to_eval_long$block_id[candidate_pairs[1, ]] == pairs_to_eval_long$block_id[candidate_pairs[2, ]] - same_truth <- pairs_to_eval_long$true_id[candidate_pairs[1, ]] == pairs_to_eval_long$true_id[candidate_pairs[2, ]] + same_block <- pairs_to_eval_long$block_id[candidate_pairs[, 1]] == pairs_to_eval_long$block_id[candidate_pairs[,2]] + same_truth <- pairs_to_eval_long$true_id[candidate_pairs[,1]] == pairs_to_eval_long$true_id[candidate_pairs[,2]] confusion <- table(same_block, same_truth) @@ -343,6 +342,7 @@ blocking <- function(x, method = ann, deduplication = deduplication, metrics = if (is.null(true_blocks)) NULL else eval_metrics, + confusion = if (is.null(true_blocks)) NULL else confusion, colnames = colnames_xy, graph = if (graph) { igraph::graph_from_data_frame(x_df[, c("x", "y")], directed = F) diff --git a/R/controls.R b/R/controls.R index f7ec116..5261de0 100644 --- a/R/controls.R +++ b/R/controls.R @@ -20,7 +20,7 @@ controls_ann <- function( sparse = FALSE, k_search = 30, nnd = list(k_build = 30, - use_alt_metric = TRUE, + use_alt_metric = FALSE, init = "tree", n_trees = NULL, leaf_size = NULL, diff --git a/R/method_nnd.R b/R/method_nnd.R index e2e0b4a..74ebe87 100644 --- a/R/method_nnd.R +++ b/R/method_nnd.R @@ -54,19 +54,19 @@ method_nnd <- function(x, ## query k dependent on the study ## there is a problem when dataset is small - if (deduplication == T) { - k_nnd_query <- k - } else if (nrow(x) < 10) { - k_nnd_query <- k - } else if (nrow(x) < control$k_search) { - k_nnd_query <- nrow(x) - } else { - k_nnd_query <- control$k_search - } +# if (deduplication == T) { +# k_nnd_query <- k +# } else if (nrow(x) < 10) { +# k_nnd_query <- k +# } else if (nrow(x) < control$k_search) { +# k_nnd_query <- nrow(x) +# } else { +# k_nnd_query <- control$k_search +# } l_1nn <- rnndescent::rnnd_query(index = l_ind, query = y, - k = k_nnd_query, + k = if (nrow(x) < control$k_search) nrow(x) else control$k_search, epsilon = 0.1, max_search_fraction = 1, init = NULL, diff --git a/R/methods.R b/R/methods.R index 0447fb4..8796737 100644 --- a/R/methods.R +++ b/R/methods.R @@ -21,7 +21,7 @@ print.blocking <- function(x,...) { cat("========================================================\n") cat("Evaluation metrics (standard):\n" ) metrics <- as.numeric(sprintf("%.4f", x$metrics*100)) - names(metrics) <- names(result2$metrics) + names(metrics) <- names(x$metrics) print(metrics) } diff --git a/README.md b/README.md index f8f9642..44ca976 100644 --- a/README.md +++ b/README.md @@ -131,11 +131,11 @@ Table with blocking results contains: blocking_result$result #> x y block dist #> -#> 1: 1 2 1 0.10000005 -#> 2: 1 3 1 0.14188367 -#> 3: 1 4 1 0.28286284 -#> 4: 5 6 2 0.08333336 -#> 5: 5 7 2 0.13397458 +#> 1: 1 2 1 0.10000002 +#> 2: 2 3 1 0.14188367 +#> 3: 2 4 1 0.28286284 +#> 4: 5 6 2 0.08333331 +#> 5: 5 7 2 0.13397455 #> 6: 5 8 2 0.27831215 ``` @@ -148,17 +148,19 @@ pair_ann(x = df_example, on = "txt") |> score_simple("score", on = "txt") |> select_threshold("threshold", score = "score", threshold = 0.55) |> link(selection = "threshold") -#> Total number of pairs: 6 pairs +#> Total number of pairs: 8 pairs #> #> Key: <.y> #> .y .x txt.x txt.y #> #> 1: 2 1 jankowalski kowalskijan #> 2: 3 1 jankowalski kowalskimjan -#> 3: 4 1 jankowalski kowaljan -#> 4: 6 5 montypython pythonmonty -#> 5: 7 5 montypython cyrkmontypython -#> 6: 8 5 montypython monty +#> 3: 3 2 kowalskijan kowalskimjan +#> 4: 4 1 jankowalski kowaljan +#> 5: 4 2 kowalskijan kowaljan +#> 6: 6 5 montypython pythonmonty +#> 7: 7 5 montypython cyrkmontypython +#> 8: 8 5 montypython monty ``` Linking records using the same function where `df_base` is the diff --git a/inst/tinytest/test_annoy.R b/inst/tinytest/test_annoy.R index 3877aea..d9d52be 100644 --- a/inst/tinytest/test_annoy.R +++ b/inst/tinytest/test_annoy.R @@ -24,6 +24,7 @@ expect_equal( method = "annoy", deduplication = FALSE, metrics = NULL, + confusion = NULL, colnames = c("al", "an", "ho", "ij", "ja", "ki", "ko", "ls", "mo", "nt", "ow", "py", "sk", "ty", "wa", "yp", "yt", "on", "th"), graph = NULL), @@ -50,6 +51,7 @@ expect_equal( method = "annoy", deduplication = FALSE, metrics = NULL, + confusion = NULL, colnames = c("al", "an", "ho", "ij", "ja", "ki", "ko", "ls", "mo", "ow", "py", "sk", "ty", "wa", "yp", "yt", "nt", "on", "th"), graph = NULL), diff --git a/inst/tinytest/test_blocking.R b/inst/tinytest/test_blocking.R index 71e75d7..8583665 100644 --- a/inst/tinytest/test_blocking.R +++ b/inst/tinytest/test_blocking.R @@ -53,7 +53,7 @@ expect_equal( expect_equal( blocking(x = df_base$txt, y = df_example$txt, ann = "lsh")$result$block, - c(rep(2,3),rep(1,4), 3) + c(rep(2,3),rep(1,4),3) ) expect_silent( diff --git a/inst/tinytest/test_hnsw.R b/inst/tinytest/test_hnsw.R index f4ac9e6..d62a0c5 100644 --- a/inst/tinytest/test_hnsw.R +++ b/inst/tinytest/test_hnsw.R @@ -25,6 +25,7 @@ expect_equal( method = "hnsw", deduplication = FALSE, metrics = NULL, + confusion = NULL, colnames = c("al", "an", "ho", "ij", "ja", "ki", "ko", "ls", "mo", "ow", "py", "sk", "ty", "wa", "yp", "yt", "nt", "on", "th"), graph = NULL), @@ -56,6 +57,7 @@ expect_equal( method = "hnsw", deduplication = FALSE, metrics = NULL, + confusion = NULL, colnames = c("al", "an", "ho", "ij", "ja", "ki", "ko", "ls", "mo", "nt", "ow", "py", "sk", "ty", "wa", "yp", "yt", "on", "th"), graph = NULL), diff --git a/inst/tinytest/test_mlpack.R b/inst/tinytest/test_mlpack.R index c3869e4..658412a 100644 --- a/inst/tinytest/test_mlpack.R +++ b/inst/tinytest/test_mlpack.R @@ -14,6 +14,7 @@ expect_equal( method = "lsh", deduplication = FALSE, metrics = NULL, + confusion = NULL, colnames = c("al", "an", "ho", "ij", "ja", "ki", "ko", "ls", "mo", "ow", "py", "sk", "ty", "wa", "yp", "yt", "nt", "on", "th"), graph = NULL), @@ -34,6 +35,7 @@ expect_equal( method = "kd", deduplication = FALSE, metrics = NULL, + confusion = NULL, colnames =c("al", "an", "ho", "ij", "ja", "ki", "ko", "ls", "mo", "ow", "py", "sk", "ty", "wa", "yp", "yt", "nt", "on", "th"), graph = NULL), @@ -67,6 +69,7 @@ expect_equal( method = "lsh", deduplication = FALSE, metrics = NULL, + confusion = NULL, colnames = c("al", "an", "ho", "ij", "ja", "ki", "ko", "ls", "mo", "nt", "ow", "py", "sk", "ty", "wa", "yp", "yt", "on", "th"), graph = NULL), @@ -88,6 +91,7 @@ expect_equal( method = "kd", deduplication = FALSE, metrics = NULL, + confusion = NULL, colnames = c("al", "an", "ho", "ij", "ja", "ki", "ko", "ls", "mo", "nt", "ow", "py", "sk", "ty", "wa", "yp", "yt", "on", "th"), graph = NULL), diff --git a/inst/tinytest/test_reclin2.R b/inst/tinytest/test_reclin2.R index 7196695..494f4cf 100644 --- a/inst/tinytest/test_reclin2.R +++ b/inst/tinytest/test_reclin2.R @@ -6,7 +6,7 @@ expect_silent( expect_equal( dim(pair_ann(x = df_example, on = "txt")), - c(6, 3) + c(8, 3) ) expect_equal( diff --git a/inst/tinytest/test_true_blocks.R b/inst/tinytest/test_true_blocks.R new file mode 100644 index 0000000..e69de29 diff --git a/man/blocking.Rd b/man/blocking.Rd index 12ad080..d40fc40 100644 --- a/man/blocking.Rd +++ b/man/blocking.Rd @@ -61,6 +61,7 @@ Returns a list with containing:\cr \itemize{ \item{\code{result} -- \code{data.table} with indices (rows) of x, y, block and distance between points} \item{\code{method} -- name of the ANN algorithm used,} +\item{\code{deduplication} -- information whether deduplication was applied,} \item{\code{metrics} -- metrics for quality assessment, if \code{true_blocks} is provided,} \item{\code{colnames} -- variable names (colnames) used for search,} \item{\code{graph} -- \code{igraph} class object.} diff --git a/man/controls_ann.Rd b/man/controls_ann.Rd index c27fe5d..03daae2 100644 --- a/man/controls_ann.Rd +++ b/man/controls_ann.Rd @@ -7,7 +7,7 @@ controls_ann( sparse = FALSE, k_search = 30, - nnd = list(k_build = 30, use_alt_metric = TRUE, init = "tree", n_trees = NULL, + nnd = list(k_build = 30, use_alt_metric = FALSE, init = "tree", n_trees = NULL, leaf_size = NULL, max_tree_depth = 200, margin = "auto", n_iters = NULL, delta = 0.001, max_candidates = NULL, low_memory = TRUE, n_search_trees = 1, pruning_degree_multiplier = 1.5, diversify_prob = 1, weight_by_degree = FALSE, diff --git a/vignettes/v2-reclin.Rmd b/vignettes/v2-reclin.Rmd index 81086ff..3e39a40 100644 --- a/vignettes/v2-reclin.Rmd +++ b/vignettes/v2-reclin.Rmd @@ -101,7 +101,7 @@ The goal of this exercise is to link units from the CIS dataset to the CENSUS da ```{r} set.seed(2024) -result1 <- blocking(x = census$txt, y = cis$txt, verbose = 1) +result1 <- blocking(x = census$txt, y = cis$txt, verbose = 1, n_threads = 8) ``` Distribution of distances for each pair. @@ -140,7 +140,7 @@ So in our example we have `r nrow(matches)` pairs. ```{r} set.seed(2024) result2 <- blocking(x = census$txt, y = cis$txt, verbose = 1, - true_blocks = matches[, .(x, y, block)], n_threads = 4) + true_blocks = matches[, .(x, y, block)], n_threads = 8) ``` Let's see how our approach handled this problem. @@ -149,7 +149,7 @@ Let's see how our approach handled this problem. result2 ``` -It seems that the default parameters of the NND method result in an FNR of `r sprintf("%.2f",result2$metrics["fnr"]*100)`%, which is quite large. We can see if increasing the number of `k` (and thus `max_candidates`) as suggested in the [Nearest Neighbor Descent +It seems that the default parameters of the NND method result in an FNR of `r sprintf("%.2f",result2$metrics["fnr"]*100)`%. We can see if increasing the number of `k` (and thus `max_candidates`) as suggested in the [Nearest Neighbor Descent ](https://jlmelville.github.io/rnndescent/articles/nearest-neighbor-descent.html) vignette will help. @@ -159,7 +159,7 @@ ann_control_pars <- controls_ann() ann_control_pars$k_search <- 60 result3 <- blocking(x = census$txt, y = cis$txt, verbose = 1, - true_blocks = matches[, .(x, y, block)], n_threads = 4, + true_blocks = matches[, .(x, y, block)], n_threads = 8, control_ann = ann_control_pars) ``` @@ -173,7 +173,7 @@ Finally, compare the NND and HNSW algorithm for this example. ```{r} result4 <- blocking(x = census$txt, y = cis$txt, verbose = 1, - true_blocks = matches[, .(x, y, block)], n_threads = 4, + true_blocks = matches[, .(x, y, block)], n_threads = 8, ann = "hnsw", seed = 2024) ```