recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py: 9 additions & 0 deletions
@@ -229,6 +229,7 @@ def __init__(
    memory_not_causal=True,              # flash attention behaves a bit more optimally if a causal mask is not explicitly passed in - but if the memories perform better without a causal mask, it is necessary to have this turned on
    add_write_to_next_write_mem=False,   # add the write memories of the previous step to the next write step - thanks to @IcarusWizard for pointing out this discrepancy
    next_write_mem_stop_grad=True,       # whether to stop the gradient of previous read memory -> next write memory
+   always_have_read_memories=True,      # whether to always have read memories, even on the first step, so as to make the model onnx-able
    resi_dual_scale=1.,                  # in the case of overflows in fp16 on the prenorm branch, set this to a value less than 1.
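For context, a minimal usage sketch showing where the newly added `always_have_read_memories` flag sits. The constructor values below are illustrative and not part of this diff, and the exact layout of the tuple returned by `forward` may vary between library versions, so it is left unpacked here.

```python
# Hypothetical sketch: only always_have_read_memories comes from this diff;
# other arguments and values are assumptions based on the repo's public API.
import torch
from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer

model = RecurrentMemoryTransformer(
    num_tokens=20000,                # vocabulary size (illustrative)
    dim=512,
    depth=6,
    num_memory_tokens=128,           # read/write memory tokens carried across segments
    seq_len=1024,                    # length of one segment
    always_have_read_memories=True   # feed read memories even on the first segment, keeping the graph fixed for ONNX export
)

x = torch.randint(0, 20000, (1, 1024))

out = model(x)  # returns logits along with the written memories; on later segments those memories are passed back in as read memories
```

With `always_have_read_memories=True`, the first segment receives placeholder read memories instead of skipping the read path entirely, so tracing sees the same set of operations on every step.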