Skip to content

Commit

Permalink
apply suggested changes
Browse files Browse the repository at this point in the history
  • Loading branch information
srzeszut committed Oct 27, 2024
1 parent 520633a commit 926a1c7
Showing 1 changed file with 29 additions and 28 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
defmodule Scholar.Impute.KNNImputer do
defmodule Scholar.Impute.KNNImputter do
@moduledoc """
Imputer for completing missing values using k-Nearest Neighbors.
Expand Down Expand Up @@ -46,21 +46,22 @@ defmodule Scholar.Impute.KNNImputer do
[`Nx.Constants.nan/0`](https://hexdocs.pm/nx/Nx.Constants.html#nan/0) values.
## Examples
iex> x = Nx.tensor([[40.0, 2.0],[4.0, 5.0],[7.0, :nan],[:nan, 8.0],[11.0, 11.0]])
iex> Scholar.Impute.KNNImputer.fit(x, number_of_neighbors: 2)
%Scholar.Impute.KNNImputer{
statistics: #Nx.Tensor<
f32[5][2]
[
[NaN, NaN],
[NaN, NaN],
[NaN, 8.0],
[7.5, NaN],
[NaN, NaN]
]
>,
missing_values: :nan
}
iex> x = Nx.tensor([[40.0, 2.0],[4.0, 5.0],[7.0, :nan],[:nan, 8.0],[11.0, 11.0]])
iex> Scholar.Impute.KNNImputter.fit(x, number_of_neighbors: 2)
%Scholar.Impute.KNNImputter{
statistics: Nx.tensor(
[
[NaN, NaN],
[NaN, NaN],
[NaN, 8.0],
[7.5, NaN],
[NaN, NaN]
]
),
missing_values: :nan
}
"""

deftransform fit(x, opts \\ []) do
Expand Down Expand Up @@ -121,8 +122,8 @@ defmodule Scholar.Impute.KNNImputer do
## Examples
iex> x = Nx.tensor([[40.0, 2.0],[4.0, 5.0],[7.0, :nan],[:nan, 8.0],[11.0, 11.0]])
iex> imputer = Scholar.Impute.KNNImputer.fit(x, number_of_neighbors: 2)
iex> Scholar.Impute.KNNImputer.transform(imputer, x)
iex> imputer = Scholar.Impute.KNNImputter.fit(x, number_of_neighbors: 2)
iex> Scholar.Impute.KNNImputter.transform(imputer, x)
Nx.tensor(
[
[40.0, 2.0],
Expand Down Expand Up @@ -191,15 +192,15 @@ defmodule Scholar.Impute.KNNImputer do
Nx.less(i, rows) do
potential_donor = x[i]

if i == nan_row do
distance = Nx.Constants.infinity({:f, 32})
row_distances = Nx.indexed_put(row_distances, Nx.new_axis(i, 0), distance)
{{i + 1, x, row_with_value_to_fill, rows, nan_row, nan_col}, row_distances}
else
distance = nan_euclidian(row_with_value_to_fill, nan_col, potential_donor)
row_distances = Nx.indexed_put(row_distances, Nx.new_axis(i, 0), distance)
{{i + 1, x, row_with_value_to_fill, rows, nan_row, nan_col}, row_distances}
end
distance =
if i == nan_row do
Nx.Constants.infinity({:f, 32})
else
nan_euclidian(row_with_value_to_fill, nan_col, potential_donor)
end

row_distances = Nx.indexed_put(row_distances, Nx.new_axis(i, 0), distance)
{{i + 1, x, row_with_value_to_fill, rows, nan_row, nan_col}, row_distances}
end

{_, indices} = Nx.top_k(-row_distances, k: num_neighbors)
Expand All @@ -221,7 +222,7 @@ defmodule Scholar.Impute.KNNImputer do

# If the potential neighbor has NaN in nan_col, we don't want to calculate the distance; the same applies when potential_neighbor is the row to impute.
{potential_neighbor} =
if potential_neighbor[nan_col] == Nx.Constants.nan() do
if Nx.is_nan(potential_neighbor[nan_col]) do
potential_neighbor = Nx.broadcast(Nx.Constants.infinity({:f, 32}), potential_neighbor)
{potential_neighbor}
else
Expand Down

0 comments on commit 926a1c7

Please sign in to comment.