The idea is to add the `available_device` pytest fixture to the appropriate metric tests. For example:
```diff
  @pytest.mark.parametrize("n_times", range(5))
- def test_multiclass_input(n_times, test_data):
+ def test_multiclass_input(n_times, available_device, test_data):
      y_pred, y, num_classes, batch_size = test_data
-     cm = ConfusionMatrix(num_classes=num_classes)
+     cm = ConfusionMatrix(num_classes=num_classes, device=available_device)
+     assert cm._device == torch.device(available_device)
      ...
```
This generates two additional variants of every `test_multiclass_input` case, so the test runs on `cpu`, on `cuda` if available, and on `mps` if available.
There is no need to add the `available_device` fixture to tests that:
- check whether an error is raised, for example `test_confusion_matrix.py::test_num_classes_wrong_input`:
  ```python
  import pytest
  from ignite.metrics import ConfusionMatrix

  def test_num_classes_wrong_input():
      with pytest.raises(ValueError, match="Argument num_classes needs to be > 1"):
          ConfusionMatrix(num_classes=1)
  ```
- run distributed tests, e.g. tests inside `TestDistributed` or `test_distrib_*`.
Files to update:
- test_average_precision.py
- test_classification_report.py
- test_cohen_kappa.py
- test_confusion_matrix.py
- test_cosine_similarity.py
- test_entropy.py
- test_fbeta.py
- test_frequency.py
- test_hsic.py
- test_js_divergence.py
- test_kl_divergence.py
- test_loss.py
- test_maximum_mean_discrepancy.py
- test_mean_absolute_error.py
- test_mean_pairwise_distance.py
- test_mean_squared_error.py
- test_multilabel_confusion_matrix.py
- test_mutual_information.py
- test_precision_recall_curve.py
- test_roc_auc.py
- test_roc_curve.py
- test_root_mean_squared_error.py
- test_top_k_categorical_accuracy.py
- test_object_detection_map.py
Many more metrics in subdirectories also need updating:
- clustering/test_calinski_harabasz_score.py ... test_silhouette_score.py
- nlp/test_bleu.py ...
- regression/...
Please split the work into multiple PRs, one PR per file.
This issue will not be assigned to a single person; it can be tackled collaboratively.
For questions or details, please ask here or on our Discord.