Skip to content

Commit 8f9c3ad

Browse files
committed
reinject the write memory positions
1 parent 8b82aaa commit 8f9c3ad

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -337,10 +337,10 @@ def forward(
337337

338338
# prepare write memories, as in paper
339339

340+
write_memories = self.init_memory(b)
341+
340342
if exists(read_memories) and self.add_write_to_next_write_mem:
341-
write_memories = read_memories
342-
else:
343-
write_memories = self.init_memory(b)
343+
write_memories = write_memories + read_memories
344344

345345
# prepare read memories
346346

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'recurrent-memory-transformer-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.5.1',
6+
version = '0.5.2',
77
license='MIT',
88
description = 'Recurrent Memory Transformer - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)