Fix replacement option in loss and gradnorm downsampling (#529)
There was an unused replacement option which defaulted to True. As a
result, torch.multinomial sampled with replacement and could choose the
same key multiple times, which crashed the storage, since its current
implementation cannot deal with duplicate keys. Duplicates should not
have been selected in the first place.
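
The effect described above can be reproduced in isolation. The following is a minimal sketch, not code from the repository: the scores and the helper name `sample_indices` are hypothetical, and only `torch.multinomial` is taken from the diff. Drawing more indices than there are samples forces repeats when `replacement=True`, while `replacement=False` always yields distinct indices.

```python
import torch

# Hypothetical scores for four samples; not taken from the repository.
scores = torch.tensor([0.1, 0.2, 0.3, 0.4])
probabilities = scores / scores.sum()


def sample_indices(probabilities: torch.Tensor, target_size: int, replacement: bool) -> list[int]:
    """Draw target_size indices from the given distribution (hypothetical helper)."""
    return torch.multinomial(probabilities, target_size, replacement=replacement).tolist()


# With replacement, drawing 8 indices from 4 samples must repeat at least one index.
with_replacement = sample_indices(probabilities, 8, replacement=True)

# Without replacement, every drawn index is distinct.
without_replacement = sample_indices(probabilities, 3, replacement=False)

print(len(set(with_replacement)), "distinct indices out of", len(with_replacement))
print(len(set(without_replacement)), "distinct indices out of", len(without_replacement))
```

Note that with `replacement=False`, `torch.multinomial` requires the number of requested indices to be no larger than the number of entries with non-zero probability.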
MaxiBoether authored Jun 19, 2024
1 parent c52a3d6 commit e6541d3
Showing 3 changed files with 2 additions and 4 deletions.
@@ -50,8 +50,6 @@ def __init__(
         assert "downsampling_ratio" in params_from_selector
         self.downsampling_ratio = params_from_selector["downsampling_ratio"]

-        self.replacement = params_from_selector.get("replacement", True)
-
         # The next variable is used to keep a mapping index <-> sample_id
         # This is needed since the data selection policy works on indexes (the policy does not care what the sample_id
         # is, it simply stores its score in a vector/matrix) but for retrieving again the data we need somehow to
@@ -84,7 +84,7 @@ def select_points(self) -> tuple[list[int], torch.Tensor]:
         probabilities = torch.cat(self.probabilities, dim=0)
         probabilities = probabilities / probabilities.sum()

-        downsampled_idxs = torch.multinomial(probabilities, target_size, replacement=self.replacement)
+        downsampled_idxs = torch.multinomial(probabilities, target_size, replacement=False)

         # lower probability, higher weight to reduce the variance
         weights = 1.0 / (self.number_of_points_seen * probabilities[downsampled_idxs])
@@ -66,7 +66,7 @@ def select_points(self) -> tuple[list[int], torch.Tensor]:
         probabilities = torch.cat(self.probabilities, dim=0)
         probabilities = probabilities / probabilities.sum()

-        downsampled_idxs = torch.multinomial(probabilities, target_size, replacement=self.replacement)
+        downsampled_idxs = torch.multinomial(probabilities, target_size, replacement=False)

         # lower probability, higher weight to reduce the variance
         weights = 1.0 / (self.number_of_points_seen * probabilities[downsampled_idxs])
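
The weight formula in the hunks above, `1.0 / (number_of_points_seen * probabilities[downsampled_idxs])`, can be checked on small inputs. This is a rough sketch under stated assumptions: the function name `importance_weights` and all input values are hypothetical, and only the formula itself comes from the diff. With uniform probabilities every selected point gets weight 1, and with skewed probabilities the less likely points get larger weights.

```python
import torch


def importance_weights(
    probabilities: torch.Tensor, downsampled_idxs: torch.Tensor, number_of_points_seen: int
) -> torch.Tensor:
    """Weight each selected point by the inverse of (N * its selection probability)."""
    return 1.0 / (number_of_points_seen * probabilities[downsampled_idxs])


number_of_points_seen = 4  # hypothetical number of scored points

# Uniform case: p = 1/N for every point, so each weight is 1 / (N * 1/N) = 1.
uniform = torch.full((number_of_points_seen,), 1.0 / number_of_points_seen)
idxs = torch.multinomial(uniform, 2, replacement=False)
uniform_weights = importance_weights(uniform, idxs, number_of_points_seen)

# Skewed case: the point with p = 0.1 gets weight 2.5, the point with p = 0.4 gets 0.625.
skewed = torch.tensor([0.1, 0.2, 0.3, 0.4])
skewed_weights = importance_weights(skewed, torch.tensor([0, 3]), number_of_points_seen)

print(uniform_weights)
print(skewed_weights)
```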
