Bug fix
fkodom committed Aug 3, 2023
1 parent 0208847 commit f06f7d9
Showing 1 changed file with 29 additions and 19 deletions.
48 changes: 29 additions & 19 deletions dilated_attention_pytorch/transformer.py
@@ -74,22 +74,25 @@ def _reset_parameters(self):
         nn.init.xavier_normal_(self.linear2.weight, gain=self.gamma_init)
         nn.init.constant_(self.linear2.bias, 0)
 
-    def forward(self, src: Tensor, is_causal: bool = False) -> Tensor:
-        x = src
-
-        # Self-attention block
+    def _self_attention_block(self, x: Tensor, is_causal: bool = False) -> Tensor:
         x = self.norm1(x)
         x, _ = self.self_attn(x, x, x, is_causal=is_causal)
         x = self.dropout(x)
+        return x
 
-        # Feedforward block
+    def _feedforward_block(self, x: Tensor) -> Tensor:
         x = self.norm2(x)
         x = self.activation(self.linear1(x))
         x = self.dropout(x)
         x = self.norm3(x)
         x = self.linear2(x)
         x = self.dropout(x)
         return x
 
+    def forward(self, src: Tensor, is_causal: bool = False) -> Tensor:
+        x = src
+        x = x + self._self_attention_block(x, is_causal=is_causal)
+        x = x + self._feedforward_block(x)
+        return x
 
 
@@ -173,31 +176,38 @@ def _reset_parameters(self):
         nn.init.xavier_normal_(self.linear2.weight, gain=self.gamma_init)
         nn.init.constant_(self.linear2.bias, 0)
 
-    def forward(
-        self,
-        tgt: Tensor,
-        memory: Tensor,
-        tgt_is_causal: bool = False,
-        memory_is_causal: bool = False,
-    ) -> Tensor:
-        x = tgt
-
-        # Self-attention block
+    def _self_attention_block(self, x: Tensor, is_causal: bool = False) -> Tensor:
         x = self.norm1(x)
-        x, _ = self.self_attn(x, x, x, is_causal=tgt_is_causal)
+        x, _ = self.self_attn(x, x, x, is_causal=is_causal)
         x = self.dropout(x)
+        return x
 
-        # Multihead attention block
+    def _multihead_attention_block(
+        self, x: Tensor, memory: Tensor, is_causal: bool = False
+    ) -> Tensor:
         x = self.norm2(x)
-        x, _ = self.multihead_attn(x, memory, memory, is_causal=memory_is_causal)
+        x, _ = self.multihead_attn(x, memory, memory, is_causal=is_causal)
         x = self.dropout(x)
+        return x
 
-        # Feedforward block
+    def _feedforward_block(self, x: Tensor) -> Tensor:
         x = self.norm3(x)
         x = self.activation(self.linear1(x))
         x = self.dropout(x)
         x = self.norm4(x)
         x = self.linear2(x)
         x = self.dropout(x)
         return x
 
+    def forward(
+        self,
+        tgt: Tensor,
+        memory: Tensor,
+        tgt_is_causal: bool = False,
+        memory_is_causal: bool = False,
+    ) -> Tensor:
+        x = tgt
+        x = x + self._self_attention_block(x, is_causal=tgt_is_causal)
+        x = x + self._multihead_attention_block(x, memory, is_causal=memory_is_causal)
+        x = x + self._feedforward_block(x)
+        return x
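
In both layers, the sub-blocks (self-attention, cross-attention in the decoder, and feedforward) are pulled out into private helper methods, and forward now adds each block's output back onto its input (x = x + block(x)) instead of overwriting x; these residual connections are the behavioral change visible in the diff, alongside routing the decoder's tgt_is_causal / memory_is_causal flags through an explicit is_causal parameter. The standalone sketch below illustrates the same pre-norm residual wiring; the class name, sizes, and use of plain nn.MultiheadAttention with an explicit causal mask are illustrative assumptions, not the repository's actual classes (which use their own dilated attention module).

import torch
from torch import Tensor, nn


class ToyPreNormEncoderLayer(nn.Module):
    """Sketch of the pre-norm residual wiring introduced by this commit.

    Illustrative assumptions: the class name, the sizes, and the use of
    plain nn.MultiheadAttention instead of the repository's attention module.
    """

    def __init__(self, d_model: int = 64, nhead: int = 4, dim_feedforward: int = 128):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(p=0.1)

    def _self_attention_block(self, x: Tensor, is_causal: bool = False) -> Tensor:
        x = self.norm1(x)
        # Build an explicit causal mask; nn.MultiheadAttention treats `is_causal`
        # only as a hint and still expects an attention mask to be supplied.
        attn_mask = None
        if is_causal:
            seq_len = x.size(1)
            attn_mask = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
                diagonal=1,
            )
        x, _ = self.self_attn(x, x, x, attn_mask=attn_mask)
        return self.dropout(x)

    def _feedforward_block(self, x: Tensor) -> Tensor:
        x = self.norm2(x)
        x = self.activation(self.linear1(x))
        x = self.dropout(x)
        x = self.norm3(x)
        x = self.linear2(x)
        return self.dropout(x)

    def forward(self, src: Tensor, is_causal: bool = False) -> Tensor:
        # Each sub-block is added back onto its input (residual connection)
        # rather than replacing it outright, mirroring the committed change.
        x = src
        x = x + self._self_attention_block(x, is_causal=is_causal)
        x = x + self._feedforward_block(x)
        return x


if __name__ == "__main__":
    layer = ToyPreNormEncoderLayer()
    x = torch.randn(2, 16, 64)  # (batch, seq_len, d_model)
    y = layer(x, is_causal=True)
    print(y.shape)  # torch.Size([2, 16, 64])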
