From 3fb7a8c3281509a7cfc5fd15c1fd7b1ad37c3632 Mon Sep 17 00:00:00 2001 From: MariaHei Date: Mon, 8 Apr 2024 15:29:03 +0100 Subject: [PATCH 1/5] Add warnings to eval.jl Deprecate eval_SC_chucks in favour of eval_SC_chunks (fixing typo) and adding warning regarding homophones/homographs in accuracy_comprehension --- src/eval.jl | 88 +++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 68 insertions(+), 20 deletions(-) diff --git a/src/eval.jl b/src/eval.jl index ff5f9d7..469881a 100644 --- a/src/eval.jl +++ b/src/eval.jl @@ -25,7 +25,8 @@ function eval_SC_loose end """ accuracy_comprehension(S, Shat, data) -Evaluate comprehension accuracy. +Evaluate comprehension accuracy for training data. +NOTE: in case of homophones/homographs in the dataset, the correct/incorrect values for base and inflections may be misleading! See below for more information. # Obligatory Arguments - `S::Matrix`: the (gold standard) S matrix @@ -47,16 +48,19 @@ accuracy_comprehension( base=[:Lexeme], inflections=[:Person, :Number, :Tense, :Voice, :Mood] ) - -accuracy_comprehension( - S_val, - Shat_val, - latin_train, - target_col=:Words, - base=["Lexeme"], - inflections=[:Person, :Number, :Tense, :Voice, :Mood] - ) ``` + +# Note +In case of homophones/homographs in the dataset, the correct/incorrect values for base and inflections may be misleading! +Consider the following example: The wordform "Äpfel" in German can be nominative plural, genitive plural and accusative plural. +Let's assume we have a dataset in which "Äpfel" occurs in all three case/number combinations (i.e. there are homographs). +If all these wordforms have the same semantic vectors (e.g. because they are derived from word2vec or fasttext which typically +have a single vector per unique wordform), the predicted semantic vector of the wordform "Äpfel" will be equally correlated +with all three case/number combinations in the dataset. 
In such cases, while the algorithm in this function can unambiguously +conclude that the correct surface form "Äpfel" was comprehended, which of the three possible rows is the correct one will be +picked somewhat non-deterministically (see https://docs.julialang.org/en/v1/base/collections/#Base.argmax). It is thus possible +that the algorithm will then use the genitive plural instead of the intended nominative plural as the ground truth, and will +report that "case" was comprehended incorrectly. """ function accuracy_comprehension( S, @@ -78,10 +82,16 @@ function accuracy_comprehension( dfr.r_target = corMat[diagind(corMat)] dfr.correct = [dfr.target[i] == dfr.form[i] for i = 1:size(dfr, 1)] + if length(data[:, target_col]) != length(Set(data[:, target_col])) + @warn "This dataset contains homophones/homographs. Note that some of the results on the correctness of comprehended base/inflections may be misleading. See documentation of this function for more information." + end + if !isnothing(inflections) all_features = vcat(base, inflections) - else + elseif !isnothing(base) all_features = base + else + all_features = [] end for f in all_features @@ -110,7 +120,9 @@ end inflections = nothing, ) -Evaluate comprehension accuracy. +Evaluate comprehension accuracy for validation data. +NOTE: in case of homophones/homographs in the dataset, the correct/incorrect values for base and inflections may be misleading! See below for more information. + # Obligatory Arguments - `S_val::Matrix`: the (gold standard) S matrix of the validation data @@ -137,6 +149,18 @@ accuracy_comprehension( inflections=[:Person, :Number, :Tense, :Voice, :Mood] ) ``` + +# Note +In case of homophones/homographs in the dataset, the correct/incorrect values for base and inflections may be misleading! +Consider the following example: The wordform "Äpfel" in German can be nominative plural, genitive plural and accusative plural. 
+Let's assume we have a dataset in which "Äpfel" occurs in all three case/number combinations (i.e. there are homographs). +If all these wordforms have the same semantic vectors (e.g. because they are derived from word2vec or fasttext which typically +have a single vector per unique wordform), the predicted semantic vector of the wordform "Äpfel" will be equally correlated +with all three case/number combinations in the dataset. In such cases, while the algorithm in this function can unambiguously +conclude that the correct surface form "Äpfel" was comprehended, which of the three possible rows is the correct one will be +picked somewhat non-deterministically (see https://docs.julialang.org/en/v1/base/collections/#Base.argmax). It is thus possible +that the algorithm will then use the genitive plural instead of the intended nominative plural as the ground truth, and will +report that "case" was comprehended incorrectly. """ function accuracy_comprehension( S_val, @@ -160,6 +184,10 @@ function accuracy_comprehension( append!(data_combined, data_train, promote=true) + if length(data_combined[:, target_col]) != length(Set(data_combined[:, target_col])) + @warn "This dataset contains homophones/homographs. Note that some of the results on the correctness of comprehended base/inflections may be misleading. See documentation of this function for more information." 
+ end + corMat = cor(Shat_val, S, dims = 2) top_index = [i[2] for i in argmax(corMat, dims = 2)] @@ -435,7 +463,7 @@ function eval_SC( # for first parts for j = 1:num_chucks-1 - correct += eval_SC_chucks( + correct += eval_SC_chunks( SChat_d, SC_d, (j - 1) * batch_size + 1, @@ -445,7 +473,7 @@ function eval_SC( verbose && ProgressMeter.next!(pb) end # for last part - correct += eval_SC_chucks( + correct += eval_SC_chunks( SChat_d, SC_d, (num_chucks - 1) * batch_size + 1, @@ -504,7 +532,7 @@ function eval_SC( # for first parts for j = 1:num_chucks-1 - correct += eval_SC_chucks( + correct += eval_SC_chunks( SChat_d, SC_d, (j - 1) * batch_size + 1, @@ -516,7 +544,7 @@ function eval_SC( verbose && ProgressMeter.next!(pb) end # for last part - correct += eval_SC_chucks( + correct += eval_SC_chunks( SChat_d, SC_d, (num_chucks - 1) * batch_size + 1, @@ -529,13 +557,18 @@ function eval_SC( round(correct / l, digits=digits) end -function eval_SC_chucks(SChat, SC, s, e, batch_size) +function eval_SC_chunks(SChat, SC, s, e, batch_size) rSC = cor(SChat[s:e, :], SC, dims = 2) v = [(rSC[i[1], i[1]+s-1] == rSC[i]) ? 1 : 0 for i in argmax(rSC, dims = 2)] sum(v) end -function eval_SC_chucks(SChat, SC, s, e, batch_size, data, target_col) +function eval_SC_chucks(SChat, SC, s, e, batch_size) + @warn "eval_SC_chucks is deprecated and will be removed in version 0.10 in favour of eval_SC_chunks" + eval_SC_chunks(SChat, SC, s, e, batch_size) +end + +function eval_SC_chunks(SChat, SC, s, e, batch_size, data, target_col) rSC = cor(SChat[s:e, :], SC, dims = 2) v = [ data[i[1]+s-1, target_col] == data[i[2], target_col] ? 
1 : 0 @@ -544,13 +577,23 @@ function eval_SC_chucks(SChat, SC, s, e, batch_size, data, target_col) sum(v) end -function eval_SC_chucks(SChat, SC, s, batch_size) +function eval_SC_chucks(SChat, SC, s, e, batch_size, data, target_col) + @warn "eval_SC_chucks is deprecated and will be removed in version 0.10 in favour of eval_SC_chunks" + eval_SC_chunks(SChat, SC, s, e, batch_size, data, target_col) +end + +function eval_SC_chunks(SChat, SC, s, batch_size) rSC = cor(SChat[s:end, :], SC, dims = 2) v = [(rSC[i[1], i[1]+s-1] == rSC[i]) ? 1 : 0 for i in argmax(rSC, dims = 2)] sum(v) end -function eval_SC_chucks(SChat, SC, s, batch_size, data, target_col) +function eval_SC_chucks(SChat, SC, s, batch_size) + @warn "eval_SC_chucks is deprecated and will be removed in version 0.10 in favour of eval_SC_chunks" + eval_SC_chunks(SChat, SC, s, batch_size) +end + +function eval_SC_chunks(SChat, SC, s, batch_size, data, target_col) rSC = cor(SChat[s:end, :], SC, dims = 2) v = [ data[i[1]+s-1, target_col] == data[i[2], target_col] ? 
1 : 0 @@ -559,6 +602,11 @@ function eval_SC_chucks(SChat, SC, s, batch_size, data, target_col) sum(v) end +function eval_SC_chucks(SChat, SC, s, batch_size, data, target_col) + @warn "eval_SC_chucks is deprecated and will be removed in version 0.10 in favour of eval_SC_chunks" + eval_SC_chunks(SChat, SC, s, batch_size, data, target_col) +end + """ eval_SC_loose(SChat, SC, k) From bfafd4d6a43cac68ee6395af128f449fd358bfe3 Mon Sep 17 00:00:00 2001 From: MariaHei Date: Mon, 8 Apr 2024 15:49:53 +0100 Subject: [PATCH 2/5] Add warnings to eval.jl Deprecate eval_SC_chucks in favour of eval_SC_chunks (fixing typo) and adding warning regarding homophones/homographs in accuracy_comprehension --- src/eval.jl | 92 +++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 72 insertions(+), 20 deletions(-) diff --git a/src/eval.jl b/src/eval.jl index ff5f9d7..a2a2037 100644 --- a/src/eval.jl +++ b/src/eval.jl @@ -25,7 +25,10 @@ function eval_SC_loose end """ accuracy_comprehension(S, Shat, data) -Evaluate comprehension accuracy. +Evaluate comprehension accuracy for training data. + +!!! note + In case of homophones/homographs in the dataset, the correct/incorrect values for base and inflections may be misleading! See below for more information. # Obligatory Arguments - `S::Matrix`: the (gold standard) S matrix @@ -47,16 +50,19 @@ accuracy_comprehension( base=[:Lexeme], inflections=[:Person, :Number, :Tense, :Voice, :Mood] ) - -accuracy_comprehension( - S_val, - Shat_val, - latin_train, - target_col=:Words, - base=["Lexeme"], - inflections=[:Person, :Number, :Tense, :Voice, :Mood] - ) ``` + +# Note +In case of homophones/homographs in the dataset, the correct/incorrect values for base and inflections may be misleading! +Consider the following example: The wordform "Äpfel" in German can be nominative plural, genitive plural and accusative plural. +Let's assume we have a dataset in which "Äpfel" occurs in all three case/number combinations (i.e. there are homographs). 
+If all these wordforms have the same semantic vectors (e.g. because they are derived from word2vec or fasttext which typically +have a single vector per unique wordform), the predicted semantic vector of the wordform "Äpfel" will be equally correlated +with all three case/number combinations in the dataset. In such cases, while the algorithm in this function can unambiguously +conclude that the correct surface form "Äpfel" was comprehended, which of the three possible rows is the correct one will be +picked somewhat non-deterministically (see https://docs.julialang.org/en/v1/base/collections/#Base.argmax). It is thus possible +that the algorithm will then use the genitive plural instead of the intended nominative plural as the ground truth, and will +report that "case" was comprehended incorrectly. """ function accuracy_comprehension( S, @@ -78,10 +84,16 @@ function accuracy_comprehension( dfr.r_target = corMat[diagind(corMat)] dfr.correct = [dfr.target[i] == dfr.form[i] for i = 1:size(dfr, 1)] + if length(data[:, target_col]) != length(Set(data[:, target_col])) + @warn "This dataset contains homophones/homographs. Note that some of the results on the correctness of comprehended base/inflections may be misleading. See documentation of this function for more information." + end + if !isnothing(inflections) all_features = vcat(base, inflections) - else + elseif !isnothing(base) all_features = base + else + all_features = [] end for f in all_features @@ -110,7 +122,11 @@ end inflections = nothing, ) -Evaluate comprehension accuracy. +Evaluate comprehension accuracy for validation data. + +!!! note + In case of homophones/homographs in the dataset, the correct/incorrect values for base and inflections may be misleading! See below for more information. 
+ # Obligatory Arguments - `S_val::Matrix`: the (gold standard) S matrix of the validation data @@ -137,6 +153,18 @@ accuracy_comprehension( inflections=[:Person, :Number, :Tense, :Voice, :Mood] ) ``` + +# Note +In case of homophones/homographs in the dataset, the correct/incorrect values for base and inflections may be misleading! +Consider the following example: The wordform "Äpfel" in German can be nominative plural, genitive plural and accusative plural. +Let's assume we have a dataset in which "Äpfel" occurs in all three case/number combinations (i.e. there are homographs). +If all these wordforms have the same semantic vectors (e.g. because they are derived from word2vec or fasttext which typically +have a single vector per unique wordform), the predicted semantic vector of the wordform "Äpfel" will be equally correlated +with all three case/number combinations in the dataset. In such cases, while the algorithm in this function can unambiguously +conclude that the correct surface form "Äpfel" was comprehended, which of the three possible rows is the correct one will be +picked somewhat non-deterministically (see https://docs.julialang.org/en/v1/base/collections/#Base.argmax). It is thus possible +that the algorithm will then use the genitive plural instead of the intended nominative plural as the ground truth, and will +report that "case" was comprehended incorrectly. """ function accuracy_comprehension( S_val, @@ -160,6 +188,10 @@ function accuracy_comprehension( append!(data_combined, data_train, promote=true) + if length(data_combined[:, target_col]) != length(Set(data_combined[:, target_col])) + @warn "This dataset contains homophones/homographs. Note that some of the results on the correctness of comprehended base/inflections may be misleading. See documentation of this function for more information." 
+ end + corMat = cor(Shat_val, S, dims = 2) top_index = [i[2] for i in argmax(corMat, dims = 2)] @@ -435,7 +467,7 @@ function eval_SC( # for first parts for j = 1:num_chucks-1 - correct += eval_SC_chucks( + correct += eval_SC_chunks( SChat_d, SC_d, (j - 1) * batch_size + 1, @@ -445,7 +477,7 @@ function eval_SC( verbose && ProgressMeter.next!(pb) end # for last part - correct += eval_SC_chucks( + correct += eval_SC_chunks( SChat_d, SC_d, (num_chucks - 1) * batch_size + 1, @@ -504,7 +536,7 @@ function eval_SC( # for first parts for j = 1:num_chucks-1 - correct += eval_SC_chucks( + correct += eval_SC_chunks( SChat_d, SC_d, (j - 1) * batch_size + 1, @@ -516,7 +548,7 @@ function eval_SC( verbose && ProgressMeter.next!(pb) end # for last part - correct += eval_SC_chucks( + correct += eval_SC_chunks( SChat_d, SC_d, (num_chucks - 1) * batch_size + 1, @@ -529,13 +561,18 @@ function eval_SC( round(correct / l, digits=digits) end -function eval_SC_chucks(SChat, SC, s, e, batch_size) +function eval_SC_chunks(SChat, SC, s, e, batch_size) rSC = cor(SChat[s:e, :], SC, dims = 2) v = [(rSC[i[1], i[1]+s-1] == rSC[i]) ? 1 : 0 for i in argmax(rSC, dims = 2)] sum(v) end -function eval_SC_chucks(SChat, SC, s, e, batch_size, data, target_col) +function eval_SC_chucks(SChat, SC, s, e, batch_size) + @warn "eval_SC_chucks is deprecated and will be removed in version 0.10 in favour of eval_SC_chunks" + eval_SC_chunks(SChat, SC, s, e, batch_size) +end + +function eval_SC_chunks(SChat, SC, s, e, batch_size, data, target_col) rSC = cor(SChat[s:e, :], SC, dims = 2) v = [ data[i[1]+s-1, target_col] == data[i[2], target_col] ? 
1 : 0 @@ -544,13 +581,23 @@ function eval_SC_chucks(SChat, SC, s, e, batch_size, data, target_col) sum(v) end -function eval_SC_chucks(SChat, SC, s, batch_size) +function eval_SC_chucks(SChat, SC, s, e, batch_size, data, target_col) + @warn "eval_SC_chucks is deprecated and will be removed in version 0.10 in favour of eval_SC_chunks" + eval_SC_chunks(SChat, SC, s, e, batch_size, data, target_col) +end + +function eval_SC_chunks(SChat, SC, s, batch_size) rSC = cor(SChat[s:end, :], SC, dims = 2) v = [(rSC[i[1], i[1]+s-1] == rSC[i]) ? 1 : 0 for i in argmax(rSC, dims = 2)] sum(v) end -function eval_SC_chucks(SChat, SC, s, batch_size, data, target_col) +function eval_SC_chucks(SChat, SC, s, batch_size) + @warn "eval_SC_chucks is deprecated and will be removed in version 0.10 in favour of eval_SC_chunks" + eval_SC_chunks(SChat, SC, s, batch_size) +end + +function eval_SC_chunks(SChat, SC, s, batch_size, data, target_col) rSC = cor(SChat[s:end, :], SC, dims = 2) v = [ data[i[1]+s-1, target_col] == data[i[2], target_col] ? 
1 : 0 @@ -559,6 +606,11 @@ function eval_SC_chucks(SChat, SC, s, batch_size, data, target_col) sum(v) end +function eval_SC_chucks(SChat, SC, s, batch_size, data, target_col) + @warn "eval_SC_chucks is deprecated and will be removed in version 0.10 in favour of eval_SC_chunks" + eval_SC_chunks(SChat, SC, s, batch_size, data, target_col) +end + """ eval_SC_loose(SChat, SC, k) From 960593e009e764e09ec9ec9cb398b449968df430 Mon Sep 17 00:00:00 2001 From: MariaHei Date: Mon, 8 Apr 2024 17:19:49 +0100 Subject: [PATCH 3/5] Fix bug in make_combined_L_matrix Make sure sd_inflection_mean and sd_inflection are passed on inside make_combined_L_matrix Fixes #62 --- src/make_semantic_matrix.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/make_semantic_matrix.jl b/src/make_semantic_matrix.jl index 5d2747f..948437d 100644 --- a/src/make_semantic_matrix.jl +++ b/src/make_semantic_matrix.jl @@ -997,7 +997,9 @@ function make_combined_L_matrix( inflections; ncol = ncol, sd_base_mean = sd_base_mean, + sd_inflection_mean = sd_inflection_mean, sd_base = sd_base, + sd_inflection = sd_inflection, seed = seed, isdeep = isdeep, ) @@ -1055,6 +1057,8 @@ function make_combined_L_matrix( ncol = ncol, sd_base_mean = sd_base_mean, sd_base = sd_base, + sd_inflection_mean = sd_inflection_mean, + sd_inflection = sd_inflection, seed = seed, isdeep = isdeep, ) @@ -1224,6 +1228,8 @@ function make_combined_S_matrix( ncol = ncol, sd_base_mean = sd_base_mean, sd_base = sd_base, + sd_inflection_mean = sd_inflection_mean, + sd_inflection = sd_inflection, seed = seed, isdeep = isdeep, ) @@ -1298,6 +1304,8 @@ function make_combined_S_matrix( ncol = ncol, sd_base_mean = sd_base_mean, sd_base = sd_base, + sd_inflection_mean = sd_inflection_mean, + sd_inflection = sd_inflection, seed = seed, isdeep = isdeep, ) From fc159fe02a2cf1cee3680e5993df66892d0690e0 Mon Sep 17 00:00:00 2001 From: MariaHei Date: Wed, 10 Apr 2024 18:20:31 +0100 Subject: [PATCH 4/5] Speed up eval_SC_loose Replace sortperm with 
partialsortperm Fixes #74 --- src/eval.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/eval.jl b/src/eval.jl index a2a2037..3bbf7d5 100644 --- a/src/eval.jl +++ b/src/eval.jl @@ -640,8 +640,7 @@ function eval_SC_loose(SChat, SC, k; digits=4) ) for i = 1:total - p = sortperm(rSC[i, :], rev = true) - p = p[1:k, :] + p = partialsortperm(rSC[i, :], 1:k, rev = true) if i in p correct += 1 end @@ -681,8 +680,7 @@ function eval_SC_loose(SChat, SC, k, data, target_col; digits=4) ) for i = 1:total - p = sortperm(rSC[i, :], rev = true) - p = p[1:k] + p = partialsortperm(rSC[i, :], 1:k, rev = true) if i in p correct += 1 else From 7c061adc5555614a8f8f746dffc0442d04e90fe5 Mon Sep 17 00:00:00 2001 From: MariaHei Date: Thu, 11 Apr 2024 09:43:16 +0100 Subject: [PATCH 5/5] Add warnings about homophones to eval_SC and eval_SC_loose --- src/eval.jl | 38 ++++++++++++++++++++++++++++++++++---- 1 file changed, 34 insertions(+), 4 deletions(-) diff --git a/src/eval.jl b/src/eval.jl index 3bbf7d5..4392266 100644 --- a/src/eval.jl +++ b/src/eval.jl @@ -85,7 +85,7 @@ function accuracy_comprehension( dfr.correct = [dfr.target[i] == dfr.form[i] for i = 1:size(dfr, 1)] if length(data[:, target_col]) != length(Set(data[:, target_col])) - @warn "This dataset contains homophones/homographs. Note that some of the results on the correctness of comprehended base/inflections may be misleading. See documentation of this function for more information." + @warn "accuracy_comprehension: This dataset contains homophones/homographs. Note that some of the results on the correctness of comprehended base/inflections may be misleading. See documentation of this function for more information." end if !isnothing(inflections) @@ -189,7 +189,7 @@ function accuracy_comprehension( append!(data_combined, data_train, promote=true) if length(data_combined[:, target_col]) != length(Set(data_combined[:, target_col])) - @warn "This dataset contains homophones/homographs. 
Note that some of the results on the correctness of comprehended base/inflections may be misleading. See documentation of this function for more information." + @warn "accuracy_comprehension: This dataset contains homophones/homographs. Note that some of the results on the correctness of comprehended base/inflections may be misleading. See documentation of this function for more information." end corMat = cor(Shat_val, S, dims = 2) @@ -232,6 +232,9 @@ Assess model accuracy on the basis of the correlations of row vectors of Chat an C or Shat and S. Ideally the target words have highest correlations on the diagonal of the pertinent correlation matrices. +!!! note + If there are homophones/homographs in the dataset, this evaluation method may be misleading: the predicted vector will be equally correlated with the target vector of both words and the one on the diagonal will not necessarily be selected as the most correlated. In such cases, supplying the dataset and `target_col` is recommended which enables taking into account homophones/homographs. + # Obligatory Arguments - `SChat::Union{SparseMatrixCSC, Matrix}`: the Chat or Shat matrix - `SC::Union{SparseMatrixCSC, Matrix}`: the C or S matrix @@ -248,6 +251,11 @@ eval_SC(Shat_val, S_val) ``` """ function eval_SC(SChat::AbstractArray, SC::AbstractArray; digits=4, R=false) + + if size(unique(SC, dims=1), 1) != size(SC, 1) + @warn "eval_SC: The C or S matrix contains duplicate vectors (usually because of homophones/homographs). Supplying the dataset and target column is recommended for a realistic evaluation. See the documentation of this function for more information." + end + rSC = cor( convert(Matrix{Float64}, SChat), convert(Matrix{Float64}, SC), @@ -273,6 +281,9 @@ of the pertinent correlation matrices. The order is important. The fist gold standard matrix has to be corresponing to the SChat matrix, such as `eval_SC(Shat_train, S_train, S_val)` or `eval_SC(Shat_val, S_val, S_train)` +!!! 
note + If there are homophones/homographs in the dataset, this evaluation method may be misleading: the predicted vector will be equally correlated with the target vector of both words and the one on the diagonal will not necessarily be selected as the most correlated. In such cases, supplying the dataset and target_col is recommended which enables taking into account homophones/homographs. + # Obligatory Arguments - `SChat::Union{SparseMatrixCSC, Matrix}`: the Chat or Shat matrix - `SC::Union{SparseMatrixCSC, Matrix}`: the training/validation C or S matrix @@ -427,7 +438,10 @@ end Assess model accuracy on the basis of the correlations of row vectors of Chat and C or Shat and S. Ideally the target words have highest correlations on the diagonal of the pertinent correlation matrices. For large datasets, pass batch_size to -process evaluation in chucks. +process evaluation in chunks. + +!!! note + If there are homophones/homographs in the dataset, this evaluation method may be misleading: the predicted vector will be equally correlated with the target vector of both words and the one on the diagonal will not necessarily be selected as the most correlated. In such cases, supplying the dataset and target_col is recommended which enables taking into account homophones/homographs. # Obligatory Arguments - `SChat`: the Chat or Shat matrix @@ -455,6 +469,10 @@ function eval_SC( verbose = false ) + if size(unique(SC, dims=1), 1) != size(SC, 1) + @warn "eval_SC: The C or S matrix contains duplicate vectors (usually because of homophones/homographs). Supplying the dataset and target column is recommended for a realistic evaluation. See the documentation of this function for more information." + end + l = size(SChat, 1) num_chucks = ceil(Int64, l / batch_size) verbose && begin @@ -494,7 +512,7 @@ end Assess model accuracy on the basis of the correlations of row vectors of Chat and C or Shat and S. 
Ideally the target words have highest correlations on the diagonal of the pertinent correlation matrices. For large datasets, pass batch_size to -process evaluation in chucks. Support homophones. +process evaluation in chunks. Support homophones. # Obligatory Arguments - `SChat::AbstractArray`: the Chat or Shat matrix @@ -617,6 +635,10 @@ end Assess model accuracy on the basis of the correlations of row vectors of Chat and C or Shat and S. Count it as correct if one of the top k candidates is correct. +!!! note + If there are homophones/homographs in the dataset, this evaluation method may be misleading: the predicted vector will be equally correlated with the target vector of both words and it is not guaranteed that the target on the diagonal will be among the k neighbours. In particular, `eval_SC` and `eval_SC_loose` with k=1 are not guaranteed to give the same result. In such cases, supplying the dataset and `target_col` is recommended which enables taking into account homophones/homographs. + + # Obligatory Arguments - `SChat::Union{SparseMatrixCSC, Matrix}`: the Chat or Shat matrix - `SC::Union{SparseMatrixCSC, Matrix}`: the C or S matrix @@ -631,6 +653,14 @@ eval_SC_loose(Shat, S, k) ``` """ function eval_SC_loose(SChat, SC, k; digits=4) + + if size(unique(SC, dims=1), 1) != size(SC, 1) + @warn "eval_SC_loose: The C or S matrix contains duplicate vectors (usually because of homophones/homographs). Supplying the dataset and target column is recommended for a realistic evaluation. See the documentation of this function for more information." + if k == 1 + @warn "eval_SC_loose: You set k=1. Note that if there are duplicate vectors in the S/C matrix, it is not guaranteed that eval_SC_loose with k=1 gives the same result as eval_SC." + end + end + total = size(SChat, 1) correct = 0 rSC = cor(