Skip to content

Commit dabfeaf

Browse files
committed
add one-hot encoding
1 parent 9800c6f commit dabfeaf

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

batcher.lua

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,16 @@ function char_to_ints(text)
3030
return alphabet, encoded
3131
end
3232

33+
-- function for one hot encoding
34+
function ints_to_one_hot(ints, width)
35+
local height = ints:size()[1]
36+
local zeros = torch.zeros(height, width)
37+
local indices = ints:view(-1, 1):long()
38+
local one_hot = zeros:scatter(2, indices, 1)
39+
40+
return one_hot
41+
end
42+
3343
-- function to generate chunks of encoded data based on chunk size
3444
function generate_chunks(encoded_text, chunk_size)
3545
local n_chunks = math.floor(encoded_text:size()[1]/chunk_size)

0 commit comments

Comments
 (0)