Skip to content

Commit

Permalink
Merge pull request #443 from aai-institute/fix/value-init
Browse files Browse the repository at this point in the history
Fix data_names in ValuationResult.zeros()
  • Loading branch information
mdbenito authored Oct 5, 2023
2 parents ac3ed99 + 0d8893c commit 15c67ec
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## Unreleased

- Fix initialization of `data_names` in `ValuationResult.zeros()`
[PR #443](https://github.com/aai-institute/pyDVL/pull/443)
- Using pytest-xdist for faster local tests
[PR #440](https://github.com/aai-institute/pyDVL/pull/440)
- Added `AntitheticPermutationSampler`
Expand Down
2 changes: 1 addition & 1 deletion src/pydvl/reporting/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def plot_ci_array(
means = np.mean(data, axis=0)
variances = np.var(data, axis=0, ddof=1)

dummy: ValuationResult[np.int_, str] = ValuationResult(
dummy: ValuationResult[np.int_, np.object_] = ValuationResult(
algorithm="dummy",
values=means,
variances=variances,
Expand Down
10 changes: 7 additions & 3 deletions src/pydvl/value/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,13 +784,17 @@ def zeros(
indices = np.arange(n_samples, dtype=np.int_)
else:
indices = np.array(indices, dtype=np.int_)

if data_names is None:
data_names = np.array(indices)
else:
data_names = np.array(data_names)

return cls(
algorithm=algorithm,
status=Status.Pending,
indices=indices,
data_names=np.array(data_names, dtype=object)
if data_names is not None
else np.empty_like(indices, dtype=object),
data_names=data_names,
values=np.zeros(len(indices)),
variances=np.zeros(len(indices)),
counts=np.zeros(len(indices), dtype=np.int_),
Expand Down

0 comments on commit 15c67ec

Please sign in to comment.