Skip to content

Commit 71ca0b2

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 649dfdb commit 71ca0b2

30 files changed

+455
-345
lines changed

examples/bert_score-own_model.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,15 @@ def __call__(self, sentences: Union[str, List[str]], max_len: int = _MAX_LEN) ->
7979
sentence.lower().split()[:max_len] + [self.PAD_TOKEN] * (max_len - len(sentence.lower().split()))
8080
for sentence in sentences
8181
]
82-
output_dict["input_ids"] = torch.cat([
83-
torch.cat([self.word2vec[word] for word in sentence]).unsqueeze(0) for sentence in tokenized_sentences
84-
])
85-
output_dict["attention_mask"] = torch.cat([
86-
torch.tensor([1 if word != self.PAD_TOKEN else 0 for word in sentence]).unsqueeze(0)
87-
for sentence in tokenized_sentences
88-
]).long()
82+
output_dict["input_ids"] = torch.cat(
83+
[torch.cat([self.word2vec[word] for word in sentence]).unsqueeze(0) for sentence in tokenized_sentences]
84+
)
85+
output_dict["attention_mask"] = torch.cat(
86+
[
87+
torch.tensor([1 if word != self.PAD_TOKEN else 0 for word in sentence]).unsqueeze(0)
88+
for sentence in tokenized_sentences
89+
]
90+
).long()
8991

9092
return output_dict
9193

src/torchmetrics/classification/accuracy.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -504,11 +504,13 @@ def __new__( # type: ignore[misc]
504504
"""Initialize task metric."""
505505
task = ClassificationTask.from_str(task)
506506

507-
kwargs.update({
508-
"multidim_average": multidim_average,
509-
"ignore_index": ignore_index,
510-
"validate_args": validate_args,
511-
})
507+
kwargs.update(
508+
{
509+
"multidim_average": multidim_average,
510+
"ignore_index": ignore_index,
511+
"validate_args": validate_args,
512+
}
513+
)
512514

513515
if task == ClassificationTask.BINARY:
514516
return BinaryAccuracy(threshold, **kwargs)

src/torchmetrics/classification/exact_match.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -406,11 +406,13 @@ def __new__(
406406
) -> Metric:
407407
"""Initialize task metric."""
408408
task = ClassificationTaskNoBinary.from_str(task)
409-
kwargs.update({
410-
"multidim_average": multidim_average,
411-
"ignore_index": ignore_index,
412-
"validate_args": validate_args,
413-
})
409+
kwargs.update(
410+
{
411+
"multidim_average": multidim_average,
412+
"ignore_index": ignore_index,
413+
"validate_args": validate_args,
414+
}
415+
)
414416
if task == ClassificationTaskNoBinary.MULTICLASS:
415417
if not isinstance(num_classes, int):
416418
raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")

src/torchmetrics/classification/f_beta.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,11 +1075,13 @@ def __new__(
10751075
"""Initialize task metric."""
10761076
task = ClassificationTask.from_str(task)
10771077
assert multidim_average is not None # noqa: S101 # needed for mypy
1078-
kwargs.update({
1079-
"multidim_average": multidim_average,
1080-
"ignore_index": ignore_index,
1081-
"validate_args": validate_args,
1082-
})
1078+
kwargs.update(
1079+
{
1080+
"multidim_average": multidim_average,
1081+
"ignore_index": ignore_index,
1082+
"validate_args": validate_args,
1083+
}
1084+
)
10831085
if task == ClassificationTask.BINARY:
10841086
return BinaryFBetaScore(beta, threshold, **kwargs)
10851087
if task == ClassificationTask.MULTICLASS:
@@ -1138,11 +1140,13 @@ def __new__(
11381140
"""Initialize task metric."""
11391141
task = ClassificationTask.from_str(task)
11401142
assert multidim_average is not None # noqa: S101 # needed for mypy
1141-
kwargs.update({
1142-
"multidim_average": multidim_average,
1143-
"ignore_index": ignore_index,
1144-
"validate_args": validate_args,
1145-
})
1143+
kwargs.update(
1144+
{
1145+
"multidim_average": multidim_average,
1146+
"ignore_index": ignore_index,
1147+
"validate_args": validate_args,
1148+
}
1149+
)
11461150
if task == ClassificationTask.BINARY:
11471151
return BinaryF1Score(threshold, **kwargs)
11481152
if task == ClassificationTask.MULTICLASS:

src/torchmetrics/classification/hamming.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -508,11 +508,13 @@ def __new__( # type: ignore[misc]
508508
"""Initialize task metric."""
509509
task = ClassificationTask.from_str(task)
510510
assert multidim_average is not None # noqa: S101 # needed for mypy
511-
kwargs.update({
512-
"multidim_average": multidim_average,
513-
"ignore_index": ignore_index,
514-
"validate_args": validate_args,
515-
})
511+
kwargs.update(
512+
{
513+
"multidim_average": multidim_average,
514+
"ignore_index": ignore_index,
515+
"validate_args": validate_args,
516+
}
517+
)
516518
if task == ClassificationTask.BINARY:
517519
return BinaryHammingDistance(threshold, **kwargs)
518520
if task == ClassificationTask.MULTICLASS:

src/torchmetrics/classification/precision_recall.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -945,11 +945,13 @@ def __new__(
945945
) -> Metric:
946946
"""Initialize task metric."""
947947
assert multidim_average is not None # noqa: S101 # needed for mypy
948-
kwargs.update({
949-
"multidim_average": multidim_average,
950-
"ignore_index": ignore_index,
951-
"validate_args": validate_args,
952-
})
948+
kwargs.update(
949+
{
950+
"multidim_average": multidim_average,
951+
"ignore_index": ignore_index,
952+
"validate_args": validate_args,
953+
}
954+
)
953955
task = ClassificationTask.from_str(task)
954956
if task == ClassificationTask.BINARY:
955957
return BinaryPrecision(threshold, **kwargs)
@@ -1011,11 +1013,13 @@ def __new__(
10111013
"""Initialize task metric."""
10121014
task = ClassificationTask.from_str(task)
10131015
assert multidim_average is not None # noqa: S101 # needed for mypy
1014-
kwargs.update({
1015-
"multidim_average": multidim_average,
1016-
"ignore_index": ignore_index,
1017-
"validate_args": validate_args,
1018-
})
1016+
kwargs.update(
1017+
{
1018+
"multidim_average": multidim_average,
1019+
"ignore_index": ignore_index,
1020+
"validate_args": validate_args,
1021+
}
1022+
)
10191023
if task == ClassificationTask.BINARY:
10201024
return BinaryRecall(threshold, **kwargs)
10211025
if task == ClassificationTask.MULTICLASS:

src/torchmetrics/classification/specificity.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -492,11 +492,13 @@ def __new__( # type: ignore[misc]
492492
"""Initialize task metric."""
493493
task = ClassificationTask.from_str(task)
494494
assert multidim_average is not None # noqa: S101 # needed for mypy
495-
kwargs.update({
496-
"multidim_average": multidim_average,
497-
"ignore_index": ignore_index,
498-
"validate_args": validate_args,
499-
})
495+
kwargs.update(
496+
{
497+
"multidim_average": multidim_average,
498+
"ignore_index": ignore_index,
499+
"validate_args": validate_args,
500+
}
501+
)
500502
if task == ClassificationTask.BINARY:
501503
return BinarySpecificity(threshold, **kwargs)
502504
if task == ClassificationTask.MULTICLASS:

src/torchmetrics/classification/stat_scores.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -531,11 +531,13 @@ def __new__(
531531
"""Initialize task metric."""
532532
task = ClassificationTask.from_str(task)
533533
assert multidim_average is not None # noqa: S101 # needed for mypy
534-
kwargs.update({
535-
"multidim_average": multidim_average,
536-
"ignore_index": ignore_index,
537-
"validate_args": validate_args,
538-
})
534+
kwargs.update(
535+
{
536+
"multidim_average": multidim_average,
537+
"ignore_index": ignore_index,
538+
"validate_args": validate_args,
539+
}
540+
)
539541
if task == ClassificationTask.BINARY:
540542
return BinaryStatScores(threshold, **kwargs)
541543
if task == ClassificationTask.MULTICLASS:

src/torchmetrics/functional/clustering/dunn_index.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ def _dunn_index_update(data: Tensor, labels: Tensor, p: float) -> Tuple[Tensor,
3939
torch.stack([a - b for a, b in combinations(centroids, 2)], dim=0), ord=p, dim=1
4040
)
4141

42-
max_intracluster_distance = torch.stack([
43-
torch.linalg.norm(ci - mu, ord=p, dim=1).max() for ci, mu in zip(clusters, centroids)
44-
])
42+
max_intracluster_distance = torch.stack(
43+
[torch.linalg.norm(ci - mu, ord=p, dim=1).max() for ci, mu in zip(clusters, centroids)]
44+
)
4545

4646
return intercluster_distance, max_intracluster_distance
4747

src/torchmetrics/functional/clustering/utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,10 +154,12 @@ def calculate_contingency_matrix(
154154
num_classes_target = target_classes.size(0)
155155

156156
contingency = torch.sparse_coo_tensor(
157-
torch.stack((
158-
target_idx,
159-
preds_idx,
160-
)),
157+
torch.stack(
158+
(
159+
target_idx,
160+
preds_idx,
161+
)
162+
),
161163
torch.ones(target_idx.shape[0], dtype=preds_idx.dtype, device=preds_idx.device),
162164
(
163165
num_classes_target,

0 commit comments

Comments
 (0)