File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -481,7 +481,7 @@ class SwinTransformer(BaseModule):
481
481
embed_dims (int): The feature dimension. Default: 96.
482
482
patch_size (int | tuple[int]): Patch size. Default: 4.
483
483
window_size (int): Window size. Default: 7.
484
- mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
484
+ mlp_ratio (int | float ): Ratio of mlp hidden dim to embedding dim.
485
485
Default: 4.
486
486
depths (tuple[int]): Depths of each Swin Transformer stage.
487
487
Default: (2, 2, 6, 2).
@@ -615,7 +615,7 @@ def __init__(self,
615
615
stage = SwinBlockSequence (
616
616
embed_dims = in_channels ,
617
617
num_heads = num_heads [i ],
618
- feedforward_channels = mlp_ratio * in_channels ,
618
+ feedforward_channels = int ( mlp_ratio * in_channels ) ,
619
619
depth = depths [i ],
620
620
window_size = window_size ,
621
621
qkv_bias = qkv_bias ,
You can’t perform that action at this time.
0 commit comments