Minor cleanup
ahmed-agiza committed Nov 25, 2023
1 parent d205bed commit d9322ab
Showing 2 changed files with 5 additions and 18 deletions.
models/swin_transformer.py (3 changes: 0 additions & 3 deletions)
@@ -600,7 +600,6 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                                fused_window_process=fused_window_process)
             self.layers.append(layer)
 
-        # self.norm = norm_layer(self.num_features)
         self.avgpool = nn.AdaptiveAvgPool1d(1)
         self.head = nn.Linear(
             self.num_features, num_classes) if num_classes > 0 else nn.Identity()
@@ -635,8 +634,6 @@ def forward_features(self, x, return_stages=False, flatten_ft=False):
             x = layer(x)
             if return_stages:
                 out.append(x)
-        # if not return_stages:
-        #     x = self.norm(x)  # B L C
         if flatten_ft:
             x = self.avgpool(x.transpose(1, 2))  # B C 1
             x = torch.flatten(x, 1)
models/swin_transformer_mtlora.py (20 changes: 5 additions & 15 deletions)
@@ -8,9 +8,8 @@
 import torch
 from torch import Tensor
 import torch.nn as nn
-import torch.utils.checkpoint as checkpoint
 from timm.models.layers import DropPath, to_2tuple, trunc_normal_
-from models.lora import MTLoRALinear, MTLoRAQKV
+from models.lora import MTLoRALinear
 
 try:
     import os
@@ -25,10 +24,6 @@
     WindowProcessReverse = None
     print("[Warning] Fused window process have not been installed. Please refer to get_started.md for installation.")
 
-lora_all_window_attenions = True
-
-# Wrapper around nn.linear
-
 
 class CompatLinear(nn.Linear):
     def __init__(self, *args, **kwargs):
@@ -45,15 +40,15 @@ def __init__(self, in_features, hidden_features=None, out_features=None, act_lay
         hidden_features = hidden_features or in_features
 
         if mtlora.FC1_ENABLED:
-            self.fc1 = MTLoRALinear(in_features, hidden_features, r=mtlora.R_PER_TASK_LIST[layer_idx] if (lora or lora_all_window_attenions) else 0,
+            self.fc1 = MTLoRALinear(in_features, hidden_features, r=mtlora.R_PER_TASK_LIST[layer_idx] if lora else 0,
                                     lora_shared_scale=mtlora.SHARED_SCALE[layer_idx], lora_task_scale=mtlora.SCALE_PER_TASK_LIST[layer_idx], lora_dropout=mtlora.DROPOUT[layer_idx], tasks=(
                                         tasks if (lora or mtlora.INTERMEDIATE_SPECIALIZATION) else None),
                                     trainable_scale_shared=mtlora.TRAINABLE_SCALE_SHARED, trainable_scale_per_task=mtlora.TRAINABLE_SCALE_PER_TASK, shared_mode=mtlora.SHARED_MODE)
         else:
             self.fc1 = CompatLinear(in_features, hidden_features)
         self.act = act_layer()
         if mtlora.FC2_ENABLED:
-            self.fc2 = MTLoRALinear(hidden_features, out_features, r=mtlora.R_PER_TASK_LIST[layer_idx] if (lora or lora_all_window_attenions) else 0,
+            self.fc2 = MTLoRALinear(hidden_features, out_features, r=mtlora.R_PER_TASK_LIST[layer_idx] if lora else 0,
                                     lora_shared_scale=mtlora.SHARED_SCALE[layer_idx], lora_task_scale=mtlora.SCALE_PER_TASK_LIST[layer_idx], lora_dropout=mtlora.DROPOUT[layer_idx], tasks=(
                                         tasks if (lora or mtlora.INTERMEDIATE_SPECIALIZATION) else None),
                                     trainable_scale_shared=mtlora.TRAINABLE_SCALE_SHARED, trainable_scale_per_task=mtlora.TRAINABLE_SCALE_PER_TASK, shared_mode=mtlora.SHARED_MODE)
@@ -157,15 +152,15 @@ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, at
         self.register_buffer("relative_position_index",
                              relative_position_index)
         if mtlora.QKV_ENABLED:
-            self.qkv = MTLoRALinear(dim, dim * 3, r=mtlora.R_PER_TASK_LIST[layer_idx] if (lora or lora_all_window_attenions) else 0,
+            self.qkv = MTLoRALinear(dim, dim * 3, r=mtlora.R_PER_TASK_LIST[layer_idx] if lora else 0,
                                     lora_shared_scale=mtlora.SHARED_SCALE[layer_idx], lora_task_scale=mtlora.SCALE_PER_TASK_LIST[layer_idx], lora_dropout=mtlora.DROPOUT[layer_idx], tasks=None, bias=qkv_bias,
                                     trainable_scale_shared=mtlora.TRAINABLE_SCALE_SHARED, trainable_scale_per_task=mtlora.TRAINABLE_SCALE_PER_TASK, shared_mode=mtlora.SHARED_MODE)
         else:
             self.qkv = CompatLinear(dim, dim * 3, bias=qkv_bias)
         self.attn_drop = nn.Dropout(attn_drop)
 
         if mtlora.PROJ_ENABLED:
-            self.proj = MTLoRALinear(dim, dim, r=mtlora.R_PER_TASK_LIST[layer_idx] if (lora or lora_all_window_attenions) else 0,
+            self.proj = MTLoRALinear(dim, dim, r=mtlora.R_PER_TASK_LIST[layer_idx] if lora else 0,
                                     lora_shared_scale=mtlora.SHARED_SCALE[layer_idx], lora_task_scale=mtlora.SCALE_PER_TASK_LIST[layer_idx], lora_dropout=mtlora.DROPOUT[layer_idx], tasks=(
                                         tasks if (lora or mtlora.INTERMEDIATE_SPECIALIZATION) else None),
                                     trainable_scale_shared=mtlora.TRAINABLE_SCALE_SHARED, trainable_scale_per_task=mtlora.TRAINABLE_SCALE_PER_TASK, shared_mode=mtlora.SHARED_MODE)
@@ -536,9 +531,6 @@ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
 
     def forward(self, x):
         for blk in self.blocks:
-            # if self.use_checkpoint:
-            #     x, tasks_lora = checkpoint.checkpoint(blk, x)
-            # else:
             x, tasks_lora = blk(x)
         if self.downsample is not None:
             x = self.downsample(x)
@@ -659,7 +651,6 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
         if mtlora is not None:
             print("\nMTLoRA params:")
             print(mtlora)
-            print(f"lora_all_window_attenions: {lora_all_window_attenions}\n")
 
         # split image into non-overlapping patches
         self.patch_embed = PatchEmbed(
@@ -705,7 +696,6 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                                      layer_idx=i_layer)
             self.layers.append(layer)
 
-        # self.norm = norm_layer(self.num_features)
         self.avgpool = nn.AdaptiveAvgPool1d(1)
         self.head = nn.Linear(
             self.num_features, num_classes) if num_classes > 0 else nn.Identity()
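
Net effect of the cleanup in models/swin_transformer_mtlora.py: the module-level lora_all_window_attenions switch (hard-coded to True, which made the `(lora or lora_all_window_attenions)` condition always true and forced the per-layer rank on every wrapped linear) is removed, along with the checkpoint and MTLoRAQKV imports and the commented-out norm/checkpoint code. The rank passed to each MTLoRALinear now depends only on that block's own lora flag. A minimal, self-contained sketch of that gating follows; the helper name effective_lora_rank and the example rank values are illustrative and not code from the repository.

# Illustrative sketch only: effective_lora_rank and the example values are
# hypothetical; r_per_task_list stands in for mtlora.R_PER_TASK_LIST in the diff.

def effective_lora_rank(layer_idx: int, lora: bool, r_per_task_list: list) -> int:
    # After this commit, only the block's own `lora` flag decides whether the
    # per-layer rank is applied; r=0 disables the LoRA path for that module.
    return r_per_task_list[layer_idx] if lora else 0


if __name__ == "__main__":
    example_ranks = [4, 4, 8, 8]  # per-stage ranks, illustrative values
    print(effective_lora_rank(2, lora=True, r_per_task_list=example_ranks))   # 8
    print(effective_lora_rank(2, lora=False, r_per_task_list=example_ranks))  # 0

Before the change, the else-branch of the rank expression was unreachable because the global was fixed at True; with the global gone, blocks constructed with lora=False now genuinely fall back to rank 0.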