
Commit

Merge branch 'dev' into model-calibration
KumoLiu authored Aug 8, 2024
2 parents 0e880a8 + 49a1e34 commit db9daeb
Showing 10 changed files with 268 additions and 65 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pythonapp.yml
@@ -99,6 +99,7 @@ jobs:
name: Install itk pre-release (Linux only)
run: |
python -m pip install --pre -U itk
find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \;
- name: Install the dependencies
run: |
python -m pip install --user --upgrade pip wheel
73 changes: 53 additions & 20 deletions monai/networks/blocks/crossattention.py
@@ -17,7 +17,7 @@
import torch.nn as nn

from monai.networks.layers.utils import get_rel_pos_embedding_layer
from monai.utils import optional_import
from monai.utils import optional_import, pytorch_after

Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")

@@ -44,6 +44,7 @@ def __init__(
rel_pos_embedding: Optional[str] = None,
input_size: Optional[Tuple] = None,
attention_dtype: Optional[torch.dtype] = None,
use_flash_attention: bool = False,
) -> None:
"""
Args:
@@ -55,13 +56,16 @@
dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads.
qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
causal: whether to use causal attention.
sequence_length: if causal is True, it is necessary to specify the sequence length.
rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map.
For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
positional parameter size.
causal (bool, optional): whether to use causal attention.
sequence_length (int, optional): if causal is True, it is necessary to specify the sequence length.
rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. For now only
"decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional
parameter size.
attention_dtype: cast attention operations to this dtype.
use_flash_attention: if True, use PyTorch's built-in
flash attention for a memory-efficient attention mechanism (see
https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
"""

super().__init__()
@@ -81,6 +85,20 @@ def __init__(
if causal and sequence_length is None:
raise ValueError("sequence_length is necessary for causal attention.")

if use_flash_attention and not pytorch_after(minor=13, major=1, patch=0):
raise ValueError(
"use_flash_attention is only supported for PyTorch versions >= 2.0."
"Upgrade your PyTorch or set the flag to False."
)
if use_flash_attention and save_attn:
raise ValueError(
"save_attn has been set to True, but use_flash_attention is also set"
"to True. save_attn can only be used if use_flash_attention is False"
)

if use_flash_attention and rel_pos_embedding is not None:
raise ValueError("rel_pos_embedding must be None if you are using flash_attention.")

self.num_heads = num_heads
self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
self.context_input_size = context_input_size if context_input_size else hidden_size
@@ -94,13 +112,15 @@ def __init__(
self.out_rearrange = Rearrange("b h l d -> b l (h d)")
self.drop_output = nn.Dropout(dropout_rate)
self.drop_weights = nn.Dropout(dropout_rate)
self.dropout_rate = dropout_rate

self.scale = self.head_dim**-0.5
self.save_attn = save_attn
self.attention_dtype = attention_dtype

self.causal = causal
self.sequence_length = sequence_length
self.use_flash_attention = use_flash_attention

if causal and sequence_length is not None:
# causal mask to ensure that attention is only applied to the left in the input sequence
@@ -142,26 +162,39 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None):
q = q.to(self.attention_dtype)
k = k.to(self.attention_dtype)

q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs)
q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs) #
k = k.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs)
v = v.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs)
att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale

# apply relative positional embedding if defined
att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat
if self.use_flash_attention:
# q, k and v are already laid out as (b, nh, t, hs), which is what
# scaled_dot_product_attention expects, so no transpose is needed;
# dropout is only applied in training mode, matching self.drop_weights.
x = torch.nn.functional.scaled_dot_product_attention(
query=q,
key=k,
value=v,
scale=self.scale,
dropout_p=self.dropout_rate if self.training else 0.0,
is_causal=self.causal,
)
else:
att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
# apply relative positional embedding if defined
if self.rel_positional_embedding is not None:
att_mat = self.rel_positional_embedding(x, att_mat, q)

if self.causal:
att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))
if self.causal:
att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))

att_mat = att_mat.softmax(dim=-1)
att_mat = att_mat.softmax(dim=-1)

if self.save_attn:
# no gradients and new tensor;
# https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
self.att_mat = att_mat.detach()
if self.save_attn:
# no gradients and new tensor;
# https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
self.att_mat = att_mat.detach()

att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
x = self.out_rearrange(x)
x = self.out_proj(x)
x = self.drop_output(x)
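A minimal usage sketch (not part of the commit): instantiating CrossAttentionBlock with the new use_flash_attention flag and running a forward pass against a separate context sequence. The parameter values and shapes are illustrative, and it assumes PyTorch >= 2.0 plus a MONAI build containing this change.

import torch
from monai.networks.blocks.crossattention import CrossAttentionBlock

block = CrossAttentionBlock(
    hidden_size=128,           # must be divisible by num_heads
    num_heads=4,
    dropout_rate=0.1,
    use_flash_attention=True,  # dispatches to torch.nn.functional.scaled_dot_product_attention
)
x = torch.randn(2, 16, 128)        # queries: (batch, tokens, hidden_size)
context = torch.randn(2, 32, 128)  # keys/values come from the context when it is given
out = block(x, context=context)    # -> (2, 16, 128)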
20 changes: 16 additions & 4 deletions monai/networks/blocks/mlp.py
@@ -11,12 +11,15 @@

from __future__ import annotations

from typing import Union

import torch.nn as nn

from monai.networks.layers import get_act_layer
from monai.networks.layers.factories import split_args
from monai.utils import look_up_option

SUPPORTED_DROPOUT_MODE = {"vit", "swin"}
SUPPORTED_DROPOUT_MODE = {"vit", "swin", "vista3d"}


class MLPBlock(nn.Module):
@@ -39,7 +42,7 @@ def __init__(
https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py#L87
"swin" corresponds to one instance as implemented in
https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_mlp.py#L23
"vista3d" mode does not use dropout.
"""

@@ -48,15 +51,24 @@ def __init__(
if not (0 <= dropout_rate <= 1):
raise ValueError("dropout_rate should be between 0 and 1.")
mlp_dim = mlp_dim or hidden_size
self.linear1 = nn.Linear(hidden_size, mlp_dim) if act != "GEGLU" else nn.Linear(hidden_size, mlp_dim * 2)
act_name, _ = split_args(act)
self.linear1 = nn.Linear(hidden_size, mlp_dim) if act_name != "GEGLU" else nn.Linear(hidden_size, mlp_dim * 2)
self.linear2 = nn.Linear(mlp_dim, hidden_size)
self.fn = get_act_layer(act)
self.drop1 = nn.Dropout(dropout_rate)
# Use Union[nn.Dropout, nn.Identity] for type annotations
self.drop1: Union[nn.Dropout, nn.Identity]
self.drop2: Union[nn.Dropout, nn.Identity]

dropout_opt = look_up_option(dropout_mode, SUPPORTED_DROPOUT_MODE)
if dropout_opt == "vit":
self.drop1 = nn.Dropout(dropout_rate)
self.drop2 = nn.Dropout(dropout_rate)
elif dropout_opt == "swin":
self.drop1 = nn.Dropout(dropout_rate)
self.drop2 = self.drop1
elif dropout_opt == "vista3d":
self.drop1 = nn.Identity()
self.drop2 = nn.Identity()
else:
raise ValueError(f"dropout_mode should be one of {SUPPORTED_DROPOUT_MODE}")

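A short sketch (not from the diff) of the three dropout modes MLPBlock now accepts; argument names follow the signature shown above and the values are illustrative.

import torch
from monai.networks.blocks.mlp import MLPBlock

x = torch.randn(2, 16, 64)
vit_mlp = MLPBlock(hidden_size=64, mlp_dim=256, dropout_rate=0.1, dropout_mode="vit")        # two independent nn.Dropout layers
swin_mlp = MLPBlock(hidden_size=64, mlp_dim=256, dropout_rate=0.1, dropout_mode="swin")      # drop2 reuses the drop1 module
vista_mlp = MLPBlock(hidden_size=64, mlp_dim=256, dropout_rate=0.1, dropout_mode="vista3d")  # nn.Identity, i.e. no dropout at all
out = vista_mlp(x)  # -> (2, 16, 64)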
58 changes: 45 additions & 13 deletions monai/networks/blocks/selfattention.py
@@ -15,9 +15,10 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

from monai.networks.layers.utils import get_rel_pos_embedding_layer
from monai.utils import optional_import
from monai.utils import optional_import, pytorch_after

Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")

@@ -42,6 +43,7 @@ def __init__(
rel_pos_embedding: Optional[str] = None,
input_size: Optional[Tuple] = None,
attention_dtype: Optional[torch.dtype] = None,
use_flash_attention: bool = False,
) -> None:
"""
Args:
@@ -59,6 +61,9 @@
input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
positional parameter size.
attention_dtype: cast attention operations to this dtype.
use_flash_attention: if True, use PyTorch's built-in
flash attention for a memory-efficient attention mechanism (see
https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
"""

@@ -82,6 +87,20 @@ def __init__(
if causal and sequence_length is None:
raise ValueError("sequence_length is necessary for causal attention.")

if use_flash_attention and not pytorch_after(minor=13, major=1, patch=0):
raise ValueError(
"use_flash_attention is only supported for PyTorch versions >= 2.0."
"Upgrade your PyTorch or set the flag to False."
)
if use_flash_attention and save_attn:
raise ValueError(
"save_attn has been set to True, but use_flash_attention is also set"
"to True. save_attn can only be used if use_flash_attention is False."
)

if use_flash_attention and rel_pos_embedding is not None:
raise ValueError("rel_pos_embedding must be None if you are using flash_attention.")

self.num_heads = num_heads
self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size)
@@ -91,12 +110,14 @@ def __init__(
self.out_rearrange = Rearrange("b h l d -> b l (h d)")
self.drop_output = nn.Dropout(dropout_rate)
self.drop_weights = nn.Dropout(dropout_rate)
self.dropout_rate = dropout_rate
self.scale = self.dim_head**-0.5
self.save_attn = save_attn
self.att_mat = torch.Tensor()
self.attention_dtype = attention_dtype
self.causal = causal
self.sequence_length = sequence_length
self.use_flash_attention = use_flash_attention

if causal and sequence_length is not None:
# causal mask to ensure that attention is only applied to the left in the input sequence
@@ -130,23 +151,34 @@ def forward(self, x):
q = q.to(self.attention_dtype)
k = k.to(self.attention_dtype)

att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
if self.use_flash_attention:
# q, k and v are already (b, nh, t, hs), the layout expected by
# scaled_dot_product_attention; dropout is skipped in eval mode to
# match the behaviour of self.drop_weights in the einsum path.
x = F.scaled_dot_product_attention(
query=q,
key=k,
value=v,
scale=self.scale,
dropout_p=self.dropout_rate if self.training else 0.0,
is_causal=self.causal,
)
else:
att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale

# apply relative positional embedding if defined
att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat
# apply relative positional embedding if defined
if self.rel_positional_embedding is not None:
att_mat = self.rel_positional_embedding(x, att_mat, q)

if self.causal:
att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[1], : x.shape[1]] == 0, float("-inf"))
if self.causal:
att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf"))

att_mat = att_mat.softmax(dim=-1)
att_mat = att_mat.softmax(dim=-1)

if self.save_attn:
# no gradients and new tensor;
# https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
self.att_mat = att_mat.detach()
if self.save_attn:
# no gradients and new tensor;
# https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
self.att_mat = att_mat.detach()

att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
x = self.out_rearrange(x)
x = self.out_proj(x)
x = self.drop_output(x)
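A minimal usage sketch (not part of the diff): causal self-attention with the new use_flash_attention flag on SABlock. It assumes PyTorch >= 2.0 and a MONAI build containing this change; sizes are illustrative.

import torch
from monai.networks.blocks.selfattention import SABlock

block = SABlock(
    hidden_size=64,
    num_heads=4,
    causal=True,             # each token attends only to itself and earlier tokens
    sequence_length=32,      # required when causal=True
    use_flash_attention=True,
)
x = torch.randn(2, 32, 64)  # (batch, sequence_length, hidden_size)
out = block(x)              # -> (2, 32, 64)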
8 changes: 7 additions & 1 deletion monai/networks/blocks/spatialattention.py
@@ -33,6 +33,7 @@ class SpatialAttentionBlock(nn.Module):
num_channels: number of input channels. Must be divisible by num_head_channels.
num_head_channels: number of channels per head.
attention_dtype: cast attention operations to this dtype.
use_flash_attention: if True, use flash attention for a memory-efficient attention mechanism.
"""

@@ -44,6 +45,7 @@ def __init__(
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
attention_dtype: Optional[torch.dtype] = None,
use_flash_attention: bool = False,
) -> None:
super().__init__()

@@ -54,7 +56,11 @@ def __init__(
raise ValueError("num_channels must be divisible by num_head_channels")
num_heads = num_channels // num_head_channels if num_head_channels is not None else 1
self.attn = SABlock(
hidden_size=num_channels, num_heads=num_heads, qkv_bias=True, attention_dtype=attention_dtype
hidden_size=num_channels,
num_heads=num_heads,
qkv_bias=True,
attention_dtype=attention_dtype,
use_flash_attention=use_flash_attention,
)

def forward(self, x: torch.Tensor):
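A small sketch (not part of the commit) showing the flag being forwarded from SpatialAttentionBlock to its inner SABlock. The spatial_dims keyword and the channel counts are assumptions chosen for illustration.

import torch
from monai.networks.blocks.spatialattention import SpatialAttentionBlock

attn = SpatialAttentionBlock(
    spatial_dims=2,
    num_channels=64,        # must be divisible by num_head_channels (and by norm_num_groups)
    num_head_channels=16,   # 64 / 16 = 4 attention heads
    use_flash_attention=True,
)
x = torch.randn(1, 64, 32, 32)  # (B, C, H, W)
out = attn(x)                   # same shape as the input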
13 changes: 11 additions & 2 deletions monai/networks/blocks/transformerblock.py
@@ -36,15 +36,18 @@ def __init__(
causal: bool = False,
sequence_length: int | None = None,
with_cross_attention: bool = False,
use_flash_attention: bool = False,
) -> None:
"""
Args:
hidden_size (int): dimension of hidden layer.
mlp_dim (int): dimension of feedforward layer.
num_heads (int): number of attention heads.
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False.
qkv_bias(bool, optional): apply bias term for the qkv linear layer. Defaults to False.
save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
use_flash_attention: if True, use PyTorch's built-in flash attention for a memory-efficient attention mechanism
(see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
"""

@@ -66,13 +69,19 @@ def __init__(
save_attn=save_attn,
causal=causal,
sequence_length=sequence_length,
use_flash_attention=use_flash_attention,
)
self.norm2 = nn.LayerNorm(hidden_size)
self.with_cross_attention = with_cross_attention

self.norm_cross_attn = nn.LayerNorm(hidden_size)
self.cross_attn = CrossAttentionBlock(
hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, qkv_bias=qkv_bias, causal=False
hidden_size=hidden_size,
num_heads=num_heads,
dropout_rate=dropout_rate,
qkv_bias=qkv_bias,
causal=False,
use_flash_attention=use_flash_attention,
)

def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
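A minimal sketch (not part of the commit): the flag now reaches both the self-attention and the cross-attention sub-blocks of TransformerBlock. Sizes are illustrative and PyTorch >= 2.0 is assumed.

import torch
from monai.networks.blocks.transformerblock import TransformerBlock

block = TransformerBlock(
    hidden_size=128,
    mlp_dim=512,
    num_heads=8,
    dropout_rate=0.0,
    with_cross_attention=True,  # enables the CrossAttentionBlock branch
    use_flash_attention=True,   # forwarded to both attention blocks
)
x = torch.randn(2, 16, 128)
context = torch.randn(2, 24, 128)
out = block(x, context=context)  # -> (2, 16, 128)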
(4 additional changed files not shown)
