Commit 90de2ac

address #19 by allowing for an option to attend to raw read memory positional embeddings on first step
1 parent 3be7d43 commit 90de2ac

2 files changed, +10 -1 lines changed

recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py

Lines changed: 9 additions & 0 deletions
@@ -229,6 +229,7 @@ def __init__(
         memory_not_causal = True,            # flash attention behaves a bit more optimally if causal mask is not explicitly passed in - but if the memories perform better without a causal mask, it is necessary to have this turned on
         add_write_to_next_write_mem = False, # add the write memories of previous step to the next write step - thanks to @IcarusWizard for pointing out this discrepancy
         next_write_mem_stop_grad = True,     # whether to stop gradient of previous read memory -> next write memory
+        always_have_read_memories = True,    # whether to always have read memories, even on the first step, so to make the model onnx-able
         resi_dual_scale = 1.,                # in the case of overflows in fp16 on the prenorm branch, set this to a value less than 1.
     ):
         super().__init__()
@@ -306,6 +307,11 @@ def __init__(
         self.add_write_to_next_write_mem = add_write_to_next_write_mem
         self.next_write_mem_stop_grad = next_write_mem_stop_grad

+        # allow for attending to raw read memory positional embeddings on first step
+        # hack to make it onnx-able and should not hurt
+
+        self.always_have_read_memories = always_have_read_memories
+
     def init_memory(self, batch):
         return repeat(self.memory_tokens, 'm d -> b m d', b = batch)

@@ -350,6 +356,9 @@ def forward(
         if exists(read_memories):
             read_mem_length = mem_length
             read_memories = read_memories + self.read_memory_emb
+        elif self.always_have_read_memories:
+            read_mem_length = mem_length
+            read_memories = repeat(self.read_memory_emb, 'n d -> b n d', b = b)
         else:
             read_mem_length = 0
             read_memories = x[:, 0:0]
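
For context, here is a minimal usage sketch of the new option. The constructor values are illustrative, the argument names follow the package's README, and the forward return signature is assumed to yield the logits and the write memories first:

import torch
from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer

model = RecurrentMemoryTransformer(
    num_tokens = 256,                  # vocabulary size (illustrative)
    num_memory_tokens = 128,           # number of read/write memory tokens
    dim = 512,
    depth = 6,
    seq_len = 1024,
    causal = True,
    always_have_read_memories = True   # the option added in this commit
)

x = torch.randint(0, 256, (1, 1024))

# first segment: there are no read memories yet, so the model attends to the raw
# read memory positional embeddings - the graph then has the same shape as on
# every later step, which is what keeps the model onnx-exportable
logits, memories, *rest = model(x)

# subsequent segments feed the returned write memories back in as read memories
logits, memories, *rest = model(x, memories)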

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'recurrent-memory-transformer-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.5.3',
+  version = '0.5.4',
   license='MIT',
   description = 'Recurrent Memory Transformer - Pytorch',
   author = 'Phil Wang',
