@@ -52,15 +52,6 @@ def forward(self, x):
5252 x , dim = (1 if self .channel_first else - 1 )) * self .scale * self .gamma .to (x ) + (self .bias .to (x ) if self .bias is not None else 0 )
5353
5454
55- class Upsample (nn .Upsample ):
56-
57- def forward (self , x ):
58- """
59- Fix bfloat16 support for nearest neighbor interpolation.
60- """
61- return super ().forward (x .float ()).type_as (x )
62-
63-
6455class Resample (nn .Module ):
6556
6657 def __init__ (self , dim , mode ):
@@ -73,11 +64,11 @@ def __init__(self, dim, mode):
7364 # layers
7465 if mode == 'upsample2d' :
7566 self .resample = nn .Sequential (
76- Upsample (scale_factor = (2. , 2. ), mode = 'nearest-exact' ),
67+ nn . Upsample (scale_factor = (2. , 2. ), mode = 'nearest-exact' ),
7768 ops .Conv2d (dim , dim // 2 , 3 , padding = 1 ))
7869 elif mode == 'upsample3d' :
7970 self .resample = nn .Sequential (
80- Upsample (scale_factor = (2. , 2. ), mode = 'nearest-exact' ),
71+ nn . Upsample (scale_factor = (2. , 2. ), mode = 'nearest-exact' ),
8172 ops .Conv2d (dim , dim // 2 , 3 , padding = 1 ))
8273 self .time_conv = CausalConv3d (
8374 dim , dim * 2 , (3 , 1 , 1 ), padding = (1 , 0 , 0 ))
0 commit comments