Skip to content

Commit

Permalink
Merge remote-tracking branch 'refs/remotes/origin/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
MariaHei committed Apr 11, 2024
2 parents eb740ec + 6f7a59c commit 50320af
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 26 deletions.
132 changes: 106 additions & 26 deletions src/eval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ function eval_SC_loose end
"""
accuracy_comprehension(S, Shat, data)
Evaluate comprehension accuracy.
Evaluate comprehension accuracy for training data.
!!! note
In case of homophones/homographs in the dataset, the correct/incorrect values for base and inflections may be misleading! See below for more information.
# Obligatory Arguments
- `S::Matrix`: the (gold standard) S matrix
Expand All @@ -47,16 +50,19 @@ accuracy_comprehension(
base=[:Lexeme],
inflections=[:Person, :Number, :Tense, :Voice, :Mood]
)
accuracy_comprehension(
S_val,
Shat_val,
latin_train,
target_col=:Words,
base=["Lexeme"],
inflections=[:Person, :Number, :Tense, :Voice, :Mood]
)
```
# Note
In case of homophones/homographs in the dataset, the correct/incorrect values for base and inflections may be misleading!
Consider the following example: The wordform "Äpfel" in German can be nominative plural, genitive plural and accusative plural.
Let's assume we have a dataset in which "Äpfel" occurs in all three case/number combinations (i.e. there are homographs).
If all these wordforms have the same semantic vectors (e.g. because they are derived from word2vec or fasttext which typically
have a single vector per unique wordform), the predicted semantic vector of the wordform "Äpfel" will be equally correlated
with all three case/number combinations in the dataset. In such cases, while the algorithm in this function can unambiguously
conclude that the correct surface form "Äpfel" was comprehended, which of the three possible rows is the correct one will be
picked somewhat non-deterministically (see https://docs.julialang.org/en/v1/base/collections/#Base.argmax). It is thus possible
that the algorithm will then use the genitive plural instead of the intended nominative plural as the ground plural, and will
report that "case" was comprehended incorrectly.
"""
function accuracy_comprehension(
S,
Expand All @@ -78,10 +84,16 @@ function accuracy_comprehension(
dfr.r_target = corMat[diagind(corMat)]
dfr.correct = [dfr.target[i] == dfr.form[i] for i = 1:size(dfr, 1)]

if length(data[:, target_col]) != length(Set(data[:, target_col]))
@warn "accuracy_comprehension: This dataset contains homophones/homographs. Note that some of the results on the correctness of comprehended base/inflections may be misleading. See documentation of this function for more information."
end

if !isnothing(inflections)
all_features = vcat(base, inflections)
else
elseif !isnothing(base)
all_features = base
else
all_features = []
end

for f in all_features
Expand Down Expand Up @@ -110,7 +122,11 @@ end
inflections = nothing,
)
Evaluate comprehension accuracy.
Evaluate comprehension accuracy for validation data.
!!! note
In case of homophones/homographs in the dataset, the correct/incorrect values for base and inflections may be misleading! See below for more information.
# Obligatory Arguments
- `S_val::Matrix`: the (gold standard) S matrix of the validation data
Expand All @@ -137,6 +153,18 @@ accuracy_comprehension(
inflections=[:Person, :Number, :Tense, :Voice, :Mood]
)
```
# Note
In case of homophones/homographs in the dataset, the correct/incorrect values for base and inflections may be misleading!
Consider the following example: The wordform "Äpfel" in German can be nominative plural, genitive plural and accusative plural.
Let's assume we have a dataset in which "Äpfel" occurs in all three case/number combinations (i.e. there are homographs).
If all these wordforms have the same semantic vectors (e.g. because they are derived from word2vec or fasttext which typically
have a single vector per unique wordform), the predicted semantic vector of the wordform "Äpfel" will be equally correlated
with all three case/number combinations in the dataset. In such cases, while the algorithm in this function can unambiguously
conclude that the correct surface form "Äpfel" was comprehended, which of the three possible rows is the correct one will be
picked somewhat non-deterministically (see https://docs.julialang.org/en/v1/base/collections/#Base.argmax). It is thus possible
that the algorithm will then use the genitive plural instead of the intended nominative plural as the ground plural, and will
report that "case" was comprehended incorrectly.
"""
function accuracy_comprehension(
S_val,
Expand All @@ -160,6 +188,10 @@ function accuracy_comprehension(

append!(data_combined, data_train, promote=true)

if length(data_combined[:, target_col]) != length(Set(data_combined[:, target_col]))
@warn "accuracy_comprehension: This dataset contains homophones/homographs. Note that some of the results on the correctness of comprehended base/inflections may be misleading. See documentation of this function for more information."
end

corMat = cor(Shat_val, S, dims = 2)
top_index = [i[2] for i in argmax(corMat, dims = 2)]

Expand Down Expand Up @@ -200,6 +232,9 @@ Assess model accuracy on the basis of the correlations of row vectors of Chat an
C or Shat and S. Ideally the target words have highest correlations on the diagonal
of the pertinent correlation matrices.
!!! note
If there are homophones/homographs in the dataset, this evaluation method may be misleading: the predicted vector will be equally correlated with the target vector of both words and the one on the diagonal will not necessarily be selected as the most correlated. In such cases, supplying the dataset and `target_col` is recommended which enables taking into account homophones/homographs.
# Obligatory Arguments
- `SChat::Union{SparseMatrixCSC, Matrix}`: the Chat or Shat matrix
- `SC::Union{SparseMatrixCSC, Matrix}`: the C or S matrix
Expand All @@ -216,6 +251,11 @@ eval_SC(Shat_val, S_val)
```
"""
function eval_SC(SChat::AbstractArray, SC::AbstractArray; digits=4, R=false)

if size(unique(SC, dims=1), 1) != size(SC, 1)
@warn "eval_SC: The C or S matrix contains duplicate vectors (usually because of homophones/homographs). Supplying the dataset and target column is recommended for a realistic evaluation. See the documentation of this function for more information."
end

rSC = cor(
convert(Matrix{Float64}, SChat),
convert(Matrix{Float64}, SC),
Expand All @@ -241,6 +281,9 @@ of the pertinent correlation matrices.
The order is important. The fist gold standard matrix has to be corresponing
to the SChat matrix, such as `eval_SC(Shat_train, S_train, S_val)` or `eval_SC(Shat_val, S_val, S_train)`
!!! note
If there are homophones/homographs in the dataset, this evaluation method may be misleading: the predicted vector will be equally correlated with the target vector of both words and the one on the diagonal will not necessarily be selected as the most correlated. In such cases, supplying the dataset and target_col is recommended which enables taking into account homophones/homographs.
# Obligatory Arguments
- `SChat::Union{SparseMatrixCSC, Matrix}`: the Chat or Shat matrix
- `SC::Union{SparseMatrixCSC, Matrix}`: the training/validation C or S matrix
Expand Down Expand Up @@ -395,7 +438,10 @@ end
Assess model accuracy on the basis of the correlations of row vectors of Chat and
C or Shat and S. Ideally the target words have highest correlations on the diagonal
of the pertinent correlation matrices. For large datasets, pass batch_size to
process evaluation in chucks.
process evaluation in chunks.
!!! note
If there are homophones/homographs in the dataset, this evaluation method may be misleading: the predicted vector will be equally correlated with the target vector of both words and the one on the diagonal will not necessarily be selected as the most correlated. In such cases, supplying the dataset and target_col is recommended which enables taking into account homophones/homographs.
# Obligatory Arguments
- `SChat`: the Chat or Shat matrix
Expand Down Expand Up @@ -423,6 +469,10 @@ function eval_SC(
verbose = false
)

if size(unique(SC, dims=1), 1) != size(SC, 1)
@warn "eval_SC: The C or S matrix contains duplicate vectors (usually because of homophones/homographs). Supplying the dataset and target column is recommended for a realistic evaluation. See the documentation of this function for more information."
end

l = size(SChat, 1)
num_chucks = ceil(Int64, l / batch_size)
verbose && begin
Expand All @@ -435,7 +485,7 @@ function eval_SC(

# for first parts
for j = 1:num_chucks-1
correct += eval_SC_chucks(
correct += eval_SC_chunks(
SChat_d,
SC_d,
(j - 1) * batch_size + 1,
Expand All @@ -445,7 +495,7 @@ function eval_SC(
verbose && ProgressMeter.next!(pb)
end
# for last part
correct += eval_SC_chucks(
correct += eval_SC_chunks(
SChat_d,
SC_d,
(num_chucks - 1) * batch_size + 1,
Expand All @@ -462,7 +512,7 @@ end
Assess model accuracy on the basis of the correlations of row vectors of Chat and
C or Shat and S. Ideally the target words have highest correlations on the diagonal
of the pertinent correlation matrices. For large datasets, pass batch_size to
process evaluation in chucks. Support homophones.
process evaluation in chunks. Support homophones.
# Obligatory Arguments
- `SChat::AbstractArray`: the Chat or Shat matrix
Expand Down Expand Up @@ -504,7 +554,7 @@ function eval_SC(

# for first parts
for j = 1:num_chucks-1
correct += eval_SC_chucks(
correct += eval_SC_chunks(
SChat_d,
SC_d,
(j - 1) * batch_size + 1,
Expand All @@ -516,7 +566,7 @@ function eval_SC(
verbose && ProgressMeter.next!(pb)
end
# for last part
correct += eval_SC_chucks(
correct += eval_SC_chunks(
SChat_d,
SC_d,
(num_chucks - 1) * batch_size + 1,
Expand All @@ -529,13 +579,18 @@ function eval_SC(
round(correct / l, digits=digits)
end

function eval_SC_chucks(SChat, SC, s, e, batch_size)
function eval_SC_chunks(SChat, SC, s, e, batch_size)
rSC = cor(SChat[s:e, :], SC, dims = 2)
v = [(rSC[i[1], i[1]+s-1] == rSC[i]) ? 1 : 0 for i in argmax(rSC, dims = 2)]
sum(v)
end

function eval_SC_chucks(SChat, SC, s, e, batch_size, data, target_col)
function eval_SC_chucks(SChat, SC, s, e, batch_size)
@warn "eval_SC_chucks is deprecated and will be removed in version 0.10 in favour of eval_SC_chunks"
eval_SC_chunks(SChat, SC, s, e, batch_size)
end

function eval_SC_chunks(SChat, SC, s, e, batch_size, data, target_col)
rSC = cor(SChat[s:e, :], SC, dims = 2)
v = [
data[i[1]+s-1, target_col] == data[i[2], target_col] ? 1 : 0
Expand All @@ -544,13 +599,23 @@ function eval_SC_chucks(SChat, SC, s, e, batch_size, data, target_col)
sum(v)
end

function eval_SC_chucks(SChat, SC, s, batch_size)
function eval_SC_chucks(SChat, SC, s, e, batch_size, data, target_col)
@warn "eval_SC_chucks is deprecated and will be removed in version 0.10 in favour of eval_SC_chunks"
eval_SC_chunks(SChat, SC, s, e, batch_size, data, target_col)
end

function eval_SC_chunks(SChat, SC, s, batch_size)
rSC = cor(SChat[s:end, :], SC, dims = 2)
v = [(rSC[i[1], i[1]+s-1] == rSC[i]) ? 1 : 0 for i in argmax(rSC, dims = 2)]
sum(v)
end

function eval_SC_chucks(SChat, SC, s, batch_size, data, target_col)
function eval_SC_chucks(SChat, SC, s, batch_size)
@warn "eval_SC_chucks is deprecated and will be removed in version 0.10 in favour of eval_SC_chunks"
eval_SC_chunks(SChat, SC, s, batch_size)
end

function eval_SC_chunks(SChat, SC, s, batch_size, data, target_col)
rSC = cor(SChat[s:end, :], SC, dims = 2)
v = [
data[i[1]+s-1, target_col] == data[i[2], target_col] ? 1 : 0
Expand All @@ -559,12 +624,21 @@ function eval_SC_chucks(SChat, SC, s, batch_size, data, target_col)
sum(v)
end

function eval_SC_chucks(SChat, SC, s, batch_size, data, target_col)
@warn "eval_SC_chucks is deprecated and will be removed in version 0.10 in favour of eval_SC_chunks"
eval_SC_chunks(SChat, SC, s, batch_size, data, target_col)
end

"""
eval_SC_loose(SChat, SC, k)
Assess model accuracy on the basis of the correlations of row vectors of Chat and
C or Shat and S. Count it as correct if one of the top k candidates is correct.
!!! note
If there are homophones/homographs in the dataset, this evaluation method may be misleading: the predicted vector will be equally correlated with the target vector of both words and it is not guaranteed that the target on the diagonal will be among the k neighbours. In particular, `eval_SC` and `eval_SC_loose` with k=1 are not guaranteed to give the same result. In such cases, supplying the dataset and `target_col` is recommended which enables taking into account homophones/homographs.
# Obligatory Arguments
- `SChat::Union{SparseMatrixCSC, Matrix}`: the Chat or Shat matrix
- `SC::Union{SparseMatrixCSC, Matrix}`: the C or S matrix
Expand All @@ -579,6 +653,14 @@ eval_SC_loose(Shat, S, k)
```
"""
function eval_SC_loose(SChat, SC, k; digits=4)

if size(unique(SC, dims=1), 1) != size(SC, 1)
@warn "eval_SC_loose: The C or S matrix contains duplicate vectors (usually because of homophones/homographs). Supplying the dataset and target column is recommended for a realistic evaluation. See the documentation of this function for more information."
if k == 1
@warn "eval_SC_loose: You set k=1. Note that if there are duplicate vectors in the S/C matrix, it is not guaranteed that eval_SC_loose with k=1 gives the same result as eval_SC."
end
end

total = size(SChat, 1)
correct = 0
rSC = cor(
Expand All @@ -588,8 +670,7 @@ function eval_SC_loose(SChat, SC, k; digits=4)
)

for i = 1:total
p = sortperm(rSC[i, :], rev = true)
p = p[1:k, :]
p = partialsortperm(rSC[i, :], 1:k, rev = true)
if i in p
correct += 1
end
Expand Down Expand Up @@ -629,8 +710,7 @@ function eval_SC_loose(SChat, SC, k, data, target_col; digits=4)
)

for i = 1:total
p = sortperm(rSC[i, :], rev = true)
p = p[1:k]
p = partialsortperm(rSC[i, :], 1:k, rev = true)
if i in p
correct += 1
else
Expand Down
8 changes: 8 additions & 0 deletions src/make_semantic_matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -997,7 +997,9 @@ function make_combined_L_matrix(
inflections;
ncol = ncol,
sd_base_mean = sd_base_mean,
sd_inflection_mean = sd_inflection_mean,
sd_base = sd_base,
sd_inflection = sd_inflection,
seed = seed,
isdeep = isdeep,
)
Expand Down Expand Up @@ -1055,6 +1057,8 @@ function make_combined_L_matrix(
ncol = ncol,
sd_base_mean = sd_base_mean,
sd_base = sd_base,
sd_inflection_mean = sd_inflection_mean,
sd_inflection = sd_inflection,
seed = seed,
isdeep = isdeep,
)
Expand Down Expand Up @@ -1224,6 +1228,8 @@ function make_combined_S_matrix(
ncol = ncol,
sd_base_mean = sd_base_mean,
sd_base = sd_base,
sd_inflection_mean = sd_inflection_mean,
sd_inflection = sd_inflection,
seed = seed,
isdeep = isdeep,
)
Expand Down Expand Up @@ -1298,6 +1304,8 @@ function make_combined_S_matrix(
ncol = ncol,
sd_base_mean = sd_base_mean,
sd_base = sd_base,
sd_inflection_mean = sd_inflection_mean,
sd_inflection = sd_inflection,
seed = seed,
isdeep = isdeep,
)
Expand Down

0 comments on commit 50320af

Please sign in to comment.