From e5a81bdab451ac14558a0115409bc017fadce3ae Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sun, 12 Dec 2021 09:52:09 -0800
Subject: [PATCH] fix a bug with weight tying across layers and number of self
 attention transformer blocks > 1, thanks to @yuanmao

---
 perceiver_pytorch/perceiver_pytorch.py | 19 ++++++++++---------
 setup.py                               |  2 +-
 2 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/perceiver_pytorch/perceiver_pytorch.py b/perceiver_pytorch/perceiver_pytorch.py
index 79d651f..deeec5c 100644
--- a/perceiver_pytorch/perceiver_pytorch.py
+++ b/perceiver_pytorch/perceiver_pytorch.py
@@ -17,16 +17,17 @@ def default(val, d):
     return val if exists(val) else d
 
 def cache_fn(f):
-    cache = None
+    cache = dict()
     @wraps(f)
-    def cached_fn(*args, _cache = True, **kwargs):
+    def cached_fn(*args, _cache = True, key = None, **kwargs):
         if not _cache:
             return f(*args, **kwargs)
         nonlocal cache
-        if cache is not None:
-            return cache
-        cache = f(*args, **kwargs)
-        return cache
+        if key in cache:
+            return cache[key]
+        result = f(*args, **kwargs)
+        cache[key] = result
+        return result
     return cached_fn
 
 def fourier_encode(x, max_freq, num_bands = 4):
@@ -196,10 +197,10 @@ def __init__(
 
         self_attns = nn.ModuleList([])
 
-        for _ in range(self_per_cross_attn):
+        for block_ind in range(self_per_cross_attn):
             self_attns.append(nn.ModuleList([
-                get_latent_attn(**cache_args),
-                get_latent_ff(**cache_args)
+                get_latent_attn(**cache_args, key = block_ind),
+                get_latent_ff(**cache_args, key = block_ind)
             ]))
 
         self.layers.append(nn.ModuleList([
diff --git a/setup.py b/setup.py
index 61ac612..8a93a63 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'perceiver-pytorch',
   packages = find_packages(),
-  version = '0.8.0',
+  version = '0.8.1',
   license='MIT',
   description = 'Perceiver - Pytorch',
   author = 'Phil Wang',
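
For reference, a minimal sketch (not part of the patch) of how the patched cache_fn behaves when weight tying is active. Before this change, the single-slot cache returned the same cached module for every call, so all self-attention blocks within a tied layer collapsed into one module whenever self_per_cross_attn > 1; keying the cache by block index keeps one module per block while still sharing it across repeated layers. cache_fn below is copied from the hunk above, and make_linear is a hypothetical stand-in for the cached get_latent_attn / get_latent_ff factories.

# Sketch only: illustrates the keyed cache, not the library's public API.
from functools import wraps
import torch.nn as nn

def cache_fn(f):
    # patched version: one cached result per key instead of a single shared slot
    cache = dict()
    @wraps(f)
    def cached_fn(*args, _cache = True, key = None, **kwargs):
        if not _cache:
            return f(*args, **kwargs)
        nonlocal cache
        if key in cache:
            return cache[key]
        result = f(*args, **kwargs)
        cache[key] = result
        return result
    return cached_fn

# hypothetical stand-in for get_latent_attn / get_latent_ff
make_linear = cache_fn(lambda: nn.Linear(8, 8))

a0 = make_linear(key = 0)           # module for self-attention block 0
a1 = make_linear(key = 1)           # block 1 now gets its own module, not a stale cache hit
assert a0 is not a1                 # distinct blocks are no longer collapsed into one
assert make_linear(key = 0) is a0   # repeated layers with the same block index still share weights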