Skip to content

Commit ab28054

Browse files
committed
y,i=torch.max(x,2) performs the max operation across rows.
1 parent 1fa8811 commit ab28054

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

encoding.lua

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ end
5555
function one_hot_to_ints(ont_hot)
5656
-- y,i=torch.max(x,1) returns the largest element in each column (across
5757
-- rows) of x, and a tensor i of their corresponding indices in x.
58-
local _, ints = torch.max(one_hot:t(), 1)
58+
-- y,i=torch.max(x,2) performs the max operation across rows.
59+
local _, ints = torch.max(one_hot, 2)
5960
return ints
6061
end
6162

0 commit comments

Comments
 (0)