Skip to content

Commit 2de23fd

Browse files
authored
Ensure hamming similarity inputs are vsa tensors (#159)
1 parent 83fdf34 commit 2de23fd

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

torchhd/functional.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,6 +1018,9 @@ def hamming_similarity(input: VSATensor, others: VSATensor) -> LongTensor:
10181018
[5, 3, 6]])
10191019
10201020
"""
1021+
input = ensure_vsa_tensor(input)
1022+
others = ensure_vsa_tensor(others)
1023+
10211024
if input.dim() > 1 and others.dim() > 1:
10221025
equals = input.unsqueeze(-2) == others.unsqueeze(-3)
10231026
return torch.sum(equals, dim=-1, dtype=torch.long)

0 commit comments

Comments
 (0)