diff --git a/src/eval.jl b/src/eval.jl
index ff5f9d7..4392266 100644
--- a/src/eval.jl
+++ b/src/eval.jl
@@ -25,7 +25,10 @@ function eval_SC_loose end
 """
     accuracy_comprehension(S, Shat, data)
 
-Evaluate comprehension accuracy.
+Evaluate comprehension accuracy for training data.
+
+!!! note
+    In case of homophones/homographs in the dataset, the correct/incorrect values for base and inflections may be misleading! See below for more information.
 
 # Obligatory Arguments
 - `S::Matrix`: the (gold standard) S matrix
@@ -47,16 +50,19 @@ accuracy_comprehension(
     base=[:Lexeme],
     inflections=[:Person, :Number, :Tense, :Voice, :Mood]
     )
-
-accuracy_comprehension(
-    S_val,
-    Shat_val,
-    latin_train,
-    target_col=:Words,
-    base=["Lexeme"],
-    inflections=[:Person, :Number, :Tense, :Voice, :Mood]
-    )
 ```
+
+# Note
+In case of homophones/homographs in the dataset, the correct/incorrect values for base and inflections may be misleading!
+Consider the following example: The wordform "Äpfel" in German can be nominative plural, genitive plural and accusative plural.
+Let's assume we have a dataset in which "Äpfel" occurs in all three case/number combinations (i.e. there are homographs).
+If all these wordforms have the same semantic vectors (e.g. because they are derived from word2vec or fastText, which typically
+have a single vector per unique wordform), the predicted semantic vector of the wordform "Äpfel" will be equally correlated
+with all three case/number combinations in the dataset. In such cases, while the algorithm in this function can unambiguously
+conclude that the correct surface form "Äpfel" was comprehended, which of the three possible rows is treated as the correct one
+is picked somewhat arbitrarily (ties are resolved by `argmax`; see https://docs.julialang.org/en/v1/base/collections/#Base.argmax).
+It is thus possible that the algorithm will use the genitive plural instead of the intended nominative plural as the ground truth,
+and will report that "case" was comprehended incorrectly.
 """
 function accuracy_comprehension(
     S,
@@ -78,10 +84,16 @@ function accuracy_comprehension(
     dfr.r_target = corMat[diagind(corMat)]
     dfr.correct = [dfr.target[i] == dfr.form[i] for i = 1:size(dfr, 1)]
 
+    if length(data[:, target_col]) != length(Set(data[:, target_col]))
+        @warn "accuracy_comprehension: This dataset contains homophones/homographs. Note that some of the results on the correctness of comprehended base/inflections may be misleading. See the documentation of this function for more information."
+    end
+
     if !isnothing(inflections)
         all_features = vcat(base, inflections)
-    else
+    elseif !isnothing(base)
         all_features = base
+    else
+        all_features = []
     end
 
     for f in all_features
@@ -110,7 +122,11 @@ end
         inflections = nothing,
     )
 
-Evaluate comprehension accuracy.
+Evaluate comprehension accuracy for validation data.
+
+!!! note
+    In case of homophones/homographs in the dataset, the correct/incorrect values for base and inflections may be misleading! See below for more information.
+
 
 # Obligatory Arguments
 - `S_val::Matrix`: the (gold standard) S matrix of the validation data
@@ -137,6 +153,18 @@ accuracy_comprehension(
     inflections=[:Person, :Number, :Tense, :Voice, :Mood]
     )
 ```
+
+# Note
+In case of homophones/homographs in the dataset, the correct/incorrect values for base and inflections may be misleading!
+Consider the following example: The wordform "Äpfel" in German can be nominative plural, genitive plural and accusative plural.
+Let's assume we have a dataset in which "Äpfel" occurs in all three case/number combinations (i.e. there are homographs).
+If all these wordforms have the same semantic vectors (e.g. because they are derived from word2vec or fastText, which typically
+have a single vector per unique wordform), the predicted semantic vector of the wordform "Äpfel" will be equally correlated
+with all three case/number combinations in the dataset. In such cases, while the algorithm in this function can unambiguously
+conclude that the correct surface form "Äpfel" was comprehended, which of the three possible rows is treated as the correct one
+is picked somewhat arbitrarily (ties are resolved by `argmax`; see https://docs.julialang.org/en/v1/base/collections/#Base.argmax).
+It is thus possible that the algorithm will use the genitive plural instead of the intended nominative plural as the ground truth,
+and will report that "case" was comprehended incorrectly.
 """
 function accuracy_comprehension(
     S_val,
@@ -160,6 +188,10 @@ function accuracy_comprehension(
 
     append!(data_combined, data_train, promote=true)
 
+    if length(data_combined[:, target_col]) != length(Set(data_combined[:, target_col]))
+        @warn "accuracy_comprehension: This dataset contains homophones/homographs. Note that some of the results on the correctness of comprehended base/inflections may be misleading. See the documentation of this function for more information."
+    end
+
     corMat = cor(Shat_val, S, dims = 2)
     top_index = [i[2] for i in argmax(corMat, dims = 2)]
@@ -200,6 +232,9 @@
 Assess model accuracy on the basis of the correlations of row vectors of Chat and
 C or Shat and S. Ideally the target words have highest correlations
 on the diagonal of the pertinent correlation matrices.
 
+!!! note
+    If there are homophones/homographs in the dataset, this evaluation method may be misleading: the predicted vector will be equally correlated with the target vectors of both words, and the one on the diagonal will not necessarily be selected as the most correlated. In such cases, supplying the dataset and `target_col` is recommended, as this allows homophones/homographs to be taken into account.
+
 # Obligatory Arguments
 - `SChat::Union{SparseMatrixCSC, Matrix}`: the Chat or Shat matrix
 - `SC::Union{SparseMatrixCSC, Matrix}`: the C or S matrix
@@ -216,6 +251,11 @@ eval_SC(Shat_val, S_val)
 ```
 """
 function eval_SC(SChat::AbstractArray, SC::AbstractArray; digits=4, R=false)
+
+    if size(unique(SC, dims=1), 1) != size(SC, 1)
+        @warn "eval_SC: The C or S matrix contains duplicate vectors (usually because of homophones/homographs). Supplying the dataset and target column is recommended for a realistic evaluation. See the documentation of this function for more information."
+    end
+
     rSC = cor(
         convert(Matrix{Float64}, SChat),
         convert(Matrix{Float64}, SC),
@@ -241,6 +281,9 @@
 of the pertinent correlation matrices. The order is important. The fist gold standard matrix has to be
 corresponing to the SChat matrix, such as `eval_SC(Shat_train, S_train, S_val)` or `eval_SC(Shat_val, S_val, S_train)`
 
+!!! note
+    If there are homophones/homographs in the dataset, this evaluation method may be misleading: the predicted vector will be equally correlated with the target vectors of both words, and the one on the diagonal will not necessarily be selected as the most correlated. In such cases, supplying the dataset and `target_col` is recommended, as this allows homophones/homographs to be taken into account.
+
 # Obligatory Arguments
 - `SChat::Union{SparseMatrixCSC, Matrix}`: the Chat or Shat matrix
 - `SC::Union{SparseMatrixCSC, Matrix}`: the training/validation C or S matrix
@@ -395,7 +438,10 @@ end
 Assess model accuracy on the basis of the correlations of row vectors of Chat and
 C or Shat and S. Ideally the target words have highest correlations
 on the diagonal of the pertinent correlation matrices. For large datasets, pass batch_size to
-process evaluation in chucks.
+process evaluation in chunks.
+
+!!! note
+    If there are homophones/homographs in the dataset, this evaluation method may be misleading: the predicted vector will be equally correlated with the target vectors of both words, and the one on the diagonal will not necessarily be selected as the most correlated. In such cases, supplying the dataset and `target_col` is recommended, as this allows homophones/homographs to be taken into account.
 
 # Obligatory Arguments
 - `SChat`: the Chat or Shat matrix
@@ -423,6 +469,10 @@ function eval_SC(
     verbose = false
 )
 
+    if size(unique(SC, dims=1), 1) != size(SC, 1)
+        @warn "eval_SC: The C or S matrix contains duplicate vectors (usually because of homophones/homographs). Supplying the dataset and target column is recommended for a realistic evaluation. See the documentation of this function for more information."
+    end
+
     l = size(SChat, 1)
     num_chucks = ceil(Int64, l / batch_size)
     verbose && begin
@@ -435,7 +485,7 @@ function eval_SC(
 
     # for first parts
     for j = 1:num_chucks-1
-        correct += eval_SC_chucks(
+        correct += eval_SC_chunks(
             SChat_d,
             SC_d,
             (j - 1) * batch_size + 1,
@@ -445,7 +495,7 @@ function eval_SC(
         verbose && ProgressMeter.next!(pb)
     end
     # for last part
-    correct += eval_SC_chucks(
+    correct += eval_SC_chunks(
         SChat_d,
         SC_d,
         (num_chucks - 1) * batch_size + 1,
@@ -462,7 +512,7 @@ end
 Assess model accuracy on the basis of the correlations of row vectors of Chat and
 C or Shat and S. Ideally the target words have highest correlations
 on the diagonal of the pertinent correlation matrices. For large datasets, pass batch_size to
-process evaluation in chucks. Support homophones.
+process evaluation in chunks. Supports homophones.
 
 # Obligatory Arguments
 - `SChat::AbstractArray`: the Chat or Shat matrix
@@ -504,7 +554,7 @@ function eval_SC(
 
     # for first parts
     for j = 1:num_chucks-1
-        correct += eval_SC_chucks(
+        correct += eval_SC_chunks(
             SChat_d,
             SC_d,
             (j - 1) * batch_size + 1,
@@ -516,7 +566,7 @@ function eval_SC(
         verbose && ProgressMeter.next!(pb)
     end
     # for last part
-    correct += eval_SC_chucks(
+    correct += eval_SC_chunks(
         SChat_d,
         SC_d,
         (num_chucks - 1) * batch_size + 1,
@@ -529,13 +579,18 @@ function eval_SC(
     round(correct / l, digits=digits)
 end
 
-function eval_SC_chucks(SChat, SC, s, e, batch_size)
+function eval_SC_chunks(SChat, SC, s, e, batch_size)
     rSC = cor(SChat[s:e, :], SC, dims = 2)
     v = [(rSC[i[1], i[1]+s-1] == rSC[i]) ? 1 : 0 for i in argmax(rSC, dims = 2)]
     sum(v)
 end
 
-function eval_SC_chucks(SChat, SC, s, e, batch_size, data, target_col)
+function eval_SC_chucks(SChat, SC, s, e, batch_size)
+    @warn "eval_SC_chucks is deprecated and will be removed in version 0.10 in favour of eval_SC_chunks"
+    eval_SC_chunks(SChat, SC, s, e, batch_size)
+end
+
+function eval_SC_chunks(SChat, SC, s, e, batch_size, data, target_col)
     rSC = cor(SChat[s:e, :], SC, dims = 2)
     v = [
         data[i[1]+s-1, target_col] == data[i[2], target_col] ? 1 : 0
@@ -544,13 +599,23 @@ function eval_SC_chucks(SChat, SC, s, e, batch_size, data, target_col)
     sum(v)
 end
 
-function eval_SC_chucks(SChat, SC, s, batch_size)
+function eval_SC_chucks(SChat, SC, s, e, batch_size, data, target_col)
+    @warn "eval_SC_chucks is deprecated and will be removed in version 0.10 in favour of eval_SC_chunks"
+    eval_SC_chunks(SChat, SC, s, e, batch_size, data, target_col)
+end
+
+function eval_SC_chunks(SChat, SC, s, batch_size)
     rSC = cor(SChat[s:end, :], SC, dims = 2)
     v = [(rSC[i[1], i[1]+s-1] == rSC[i]) ? 1 : 0 for i in argmax(rSC, dims = 2)]
     sum(v)
 end
 
-function eval_SC_chucks(SChat, SC, s, batch_size, data, target_col)
+function eval_SC_chucks(SChat, SC, s, batch_size)
+    @warn "eval_SC_chucks is deprecated and will be removed in version 0.10 in favour of eval_SC_chunks"
+    eval_SC_chunks(SChat, SC, s, batch_size)
+end
+
+function eval_SC_chunks(SChat, SC, s, batch_size, data, target_col)
     rSC = cor(SChat[s:end, :], SC, dims = 2)
     v = [
         data[i[1]+s-1, target_col] == data[i[2], target_col] ? 1 : 0
@@ -559,12 +624,21 @@ function eval_SC_chucks(SChat, SC, s, batch_size, data, target_col)
     sum(v)
 end
 
+function eval_SC_chucks(SChat, SC, s, batch_size, data, target_col)
+    @warn "eval_SC_chucks is deprecated and will be removed in version 0.10 in favour of eval_SC_chunks"
+    eval_SC_chunks(SChat, SC, s, batch_size, data, target_col)
+end
+
 """
     eval_SC_loose(SChat, SC, k)
 
 Assess model accuracy on the basis of the correlations of row vectors of Chat and
 C or Shat and S. Count it as correct if one of the top k candidates is correct.
 
+!!! note
+    If there are homophones/homographs in the dataset, this evaluation method may be misleading: the predicted vector will be equally correlated with the target vectors of both words, and it is not guaranteed that the target on the diagonal will be among the k neighbours. In particular, `eval_SC` and `eval_SC_loose` with `k=1` are not guaranteed to give the same result. In such cases, supplying the dataset and `target_col` is recommended, as this allows homophones/homographs to be taken into account.
+
+
 # Obligatory Arguments
 - `SChat::Union{SparseMatrixCSC, Matrix}`: the Chat or Shat matrix
 - `SC::Union{SparseMatrixCSC, Matrix}`: the C or S matrix
@@ -579,6 +653,14 @@ eval_SC_loose(Shat, S, k)
 ```
 """
 function eval_SC_loose(SChat, SC, k; digits=4)
+
+    if size(unique(SC, dims=1), 1) != size(SC, 1)
+        @warn "eval_SC_loose: The C or S matrix contains duplicate vectors (usually because of homophones/homographs). Supplying the dataset and target column is recommended for a realistic evaluation. See the documentation of this function for more information."
+        if k == 1
+            @warn "eval_SC_loose: You set k=1. Note that if there are duplicate vectors in the S/C matrix, it is not guaranteed that eval_SC_loose with k=1 gives the same result as eval_SC."
+        end
+    end
+
     total = size(SChat, 1)
     correct = 0
     rSC = cor(
@@ -588,8 +670,7 @@ function eval_SC_loose(SChat, SC, k; digits=4)
     )
 
     for i = 1:total
-        p = sortperm(rSC[i, :], rev = true)
-        p = p[1:k, :]
+        p = partialsortperm(rSC[i, :], 1:k, rev = true)
         if i in p
             correct += 1
         end
@@ -629,8 +710,7 @@ function eval_SC_loose(SChat, SC, k, data, target_col; digits=4)
     )
 
     for i = 1:total
-        p = sortperm(rSC[i, :], rev = true)
-        p = p[1:k]
+        p = partialsortperm(rSC[i, :], 1:k, rev = true)
         if i in p
             correct += 1
         else
diff --git a/src/make_semantic_matrix.jl b/src/make_semantic_matrix.jl
index 5d2747f..948437d 100644
--- a/src/make_semantic_matrix.jl
+++ b/src/make_semantic_matrix.jl
@@ -997,7 +997,9 @@ function make_combined_L_matrix(
         inflections;
         ncol = ncol,
         sd_base_mean = sd_base_mean,
+        sd_inflection_mean = sd_inflection_mean,
         sd_base = sd_base,
+        sd_inflection = sd_inflection,
         seed = seed,
         isdeep = isdeep,
     )
@@ -1055,6 +1057,8 @@ function make_combined_L_matrix(
         ncol = ncol,
         sd_base_mean = sd_base_mean,
         sd_base = sd_base,
+        sd_inflection_mean = sd_inflection_mean,
+        sd_inflection = sd_inflection,
         seed = seed,
         isdeep = isdeep,
     )
@@ -1224,6 +1228,8 @@ function make_combined_S_matrix(
         ncol = ncol,
         sd_base_mean = sd_base_mean,
         sd_base = sd_base,
+        sd_inflection_mean = sd_inflection_mean,
+        sd_inflection = sd_inflection,
         seed = seed,
         isdeep = isdeep,
     )
@@ -1298,6 +1304,8 @@ function make_combined_S_matrix(
         ncol = ncol,
         sd_base_mean = sd_base_mean,
         sd_base = sd_base,
+        sd_inflection_mean = sd_inflection_mean,
+        sd_inflection = sd_inflection,
         seed = seed,
         isdeep = isdeep,
     )
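The tie-breaking behaviour described in the two `accuracy_comprehension` notes above can be reproduced in isolation. The following toy sketch uses made-up numbers (nothing here comes from the package; only `Statistics.cor` and `Base.argmax` are assumed) to show that among several equally correlated homograph rows, the first maximal index always wins:

```julia
using Statistics

# Toy gold-standard S matrix: rows 1-3 stand for the homographs ("Äpfel" as
# nominative/genitive/accusative plural) and share one semantic vector;
# row 4 is an unrelated word.
S = [1.0 0.0 0.5;
     1.0 0.0 0.5;
     1.0 0.0 0.5;
     0.0 1.0 0.3]

# Predicted semantic vector for one "Äpfel" token.
Shat = [0.9 0.1 0.4]

rSC = cor(Shat, S, dims = 2)   # 1×4 row-wise correlations
best = argmax(rSC, dims = 2)   # ties are resolved in favour of the first index

# rSC[1, 1] == rSC[1, 2] == rSC[1, 3], so `best` always points at row 1,
# regardless of which case/number row the token actually belonged to.
```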
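As a usage note for the duplicate-vector warnings added to `eval_SC`: the plain call judges correctness from the correlation matrix alone, while the variant that receives the dataset compares word forms. The sketch below is hypothetical; the toy data, the `:Words` column and the assumption that the dataset-and-target-column method is the one the docstring notes refer to are mine, and the functions from this file are assumed to be in scope (e.g. via the owning package).

```julia
using DataFrames

# Hypothetical toy data: "bank" occurs twice (a homograph pair sharing one
# semantic vector), "tree" once.
data = DataFrame(Words = ["bank", "bank", "tree"])
S    = [1.0 0.0; 1.0 0.0; 0.0 1.0]    # rows 1 and 2 are duplicate vectors
Shat = [0.9 0.1; 0.9 0.1; 0.1 0.9]

# Plain evaluation: duplicate rows in S trigger the new warning, and correctness
# is decided from the correlation matrix alone.
eval_SC(Shat, S)

# Homograph-aware evaluation: rows are compared via the word form in :Words,
# so either "bank" row counts as a correct target for a "bank" prediction.
eval_SC(Shat, S, data, :Words)
```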
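The `eval_SC_loose` hunks replace a full `sortperm` of each correlation row with `partialsortperm`, which only locates the indices of the `k` largest entries. A minimal standalone sketch (standard library only, made-up values) of why the two are interchangeable for the top-`k` membership check:

```julia
# One row of a correlation matrix and the number of neighbours to keep.
row = [0.21, 0.93, 0.47, 0.93, 0.08]
k = 2

# Old approach: sort the whole row, then keep the first k indices.
p_full = sortperm(row, rev = true)[1:k]

# New approach: only the k largest entries are located.
p_top = partialsortperm(row, 1:k, rev = true)

# Both select the indices of the two 0.93 entries (order among ties may differ).
@assert sort(row[p_full], rev = true) == sort(row[p_top], rev = true)
```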
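For the make_semantic_matrix.jl hunks, the point of forwarding `sd_inflection_mean` and `sd_inflection` is that these keyword arguments previously never reached the inner calls that simulate the inflectional semantic vectors. A hypothetical call sketch follows; the positional arguments (training data, validation data, base features, inflectional features) and the data names are assumptions based on the keyword names and docstring examples above, not taken from this diff:

```julia
# Hypothetical sketch: after the fix, the inflectional spread parameters below
# are actually forwarded instead of being silently dropped.
S_train, S_val = make_combined_S_matrix(
    latin_train,
    latin_val,
    ["Lexeme"],
    [:Person, :Number, :Tense, :Voice, :Mood];
    ncol = 200,
    sd_inflection_mean = 1,
    sd_inflection = 2,
)
```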